diff --git a/AdaptiveWingLoss/__pycache__/aux.cpython-310.pyc b/AdaptiveWingLoss/__pycache__/aux.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e28039557364b7f54e55322f5bb07fd52e1c92b3
Binary files /dev/null and b/AdaptiveWingLoss/__pycache__/aux.cpython-310.pyc differ
diff --git a/AdaptiveWingLoss/aux.py b/AdaptiveWingLoss/aux.py
new file mode 100644
index 0000000000000000000000000000000000000000..f566bf532405bdaeb350e7b50dcffb4d328835c3
--- /dev/null
+++ b/AdaptiveWingLoss/aux.py
@@ -0,0 +1,4 @@
+def detect_landmarks(inputs, model_ft):
+ outputs, _ = model_ft(inputs)
+ pred_heatmap = outputs[-1][:, :-1, :, :]
+ return pred_heatmap[:, 96, :, :], pred_heatmap[:, 97, :, :]
diff --git a/AdaptiveWingLoss/core/__init__.py b/AdaptiveWingLoss/core/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AdaptiveWingLoss/core/__pycache__/__init__.cpython-310.pyc b/AdaptiveWingLoss/core/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b8d178164b7898d4c7e6b6acbadd1db32e19e41
Binary files /dev/null and b/AdaptiveWingLoss/core/__pycache__/__init__.cpython-310.pyc differ
diff --git a/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-310.pyc b/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d57e2e63717668e6923664247a43c5936336a5cf
Binary files /dev/null and b/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-310.pyc differ
diff --git a/AdaptiveWingLoss/core/__pycache__/models.cpython-310.pyc b/AdaptiveWingLoss/core/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e8b11af1545643eada6b1248255f34a014f0d60
Binary files /dev/null and b/AdaptiveWingLoss/core/__pycache__/models.cpython-310.pyc differ
diff --git a/AdaptiveWingLoss/core/coord_conv.py b/AdaptiveWingLoss/core/coord_conv.py
new file mode 100755
index 0000000000000000000000000000000000000000..37cae2d71b1c0211534d64a8da4c28530b13efe4
--- /dev/null
+++ b/AdaptiveWingLoss/core/coord_conv.py
@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+
+
+class AddCoordsTh(nn.Module):
+ def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
+ super(AddCoordsTh, self).__init__()
+ self.x_dim = x_dim
+ self.y_dim = y_dim
+ self.with_r = with_r
+ self.with_boundary = with_boundary
+
+ def forward(self, input_tensor, heatmap=None):
+ """
+ input_tensor: (batch, c, x_dim, y_dim)
+ """
+ batch_size_tensor = input_tensor.shape[0]
+
+ xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(input_tensor.device)
+ xx_ones = xx_ones.unsqueeze(-1)
+
+ xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor.device)
+ xx_range = xx_range.unsqueeze(1)
+
+ xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
+ xx_channel = xx_channel.unsqueeze(-1)
+
+ yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(input_tensor.device)
+ yy_ones = yy_ones.unsqueeze(1)
+
+ yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor.device)
+ yy_range = yy_range.unsqueeze(-1)
+
+ yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
+ yy_channel = yy_channel.unsqueeze(-1)
+
+ xx_channel = xx_channel.permute(0, 3, 2, 1)
+ yy_channel = yy_channel.permute(0, 3, 2, 1)
+
+ xx_channel = xx_channel / (self.x_dim - 1)
+ yy_channel = yy_channel / (self.y_dim - 1)
+
+ xx_channel = xx_channel * 2 - 1
+ yy_channel = yy_channel * 2 - 1
+
+ xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
+ yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
+
+ if self.with_boundary and type(heatmap) != type(None):
+ boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
+
+ zero_tensor = torch.zeros_like(xx_channel)
+ xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor)
+ yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor)
+ if self.with_boundary and type(heatmap) != type(None):
+ xx_boundary_channel = xx_boundary_channel.to(input_tensor.device)
+ yy_boundary_channel = yy_boundary_channel.to(input_tensor.device)
+ ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
+
+ if self.with_r:
+ rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
+ rr = rr / torch.max(rr)
+ ret = torch.cat([ret, rr], dim=1)
+
+ if self.with_boundary and type(heatmap) != type(None):
+ ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1)
+ return ret
+
+
+class CoordConvTh(nn.Module):
+ """CoordConv layer as in the paper."""
+
+ def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs):
+ super(CoordConvTh, self).__init__()
+ self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary)
+ in_channels += 2
+ if with_r:
+ in_channels += 1
+ if with_boundary and not first_one:
+ in_channels += 2
+ self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
+
+ def forward(self, input_tensor, heatmap=None):
+ ret = self.addcoords(input_tensor, heatmap)
+ last_channel = ret[:, -2:, :, :]
+ ret = self.conv(ret)
+ return ret, last_channel
+
+
+"""
+An alternative implementation for PyTorch with auto-infering the x-y dimensions.
+"""
+
+
+class AddCoords(nn.Module):
+ def __init__(self, with_r=False):
+ super().__init__()
+ self.with_r = with_r
+
+ def forward(self, input_tensor):
+ """
+ Args:
+ input_tensor: shape(batch, channel, x_dim, y_dim)
+ """
+ batch_size, _, x_dim, y_dim = input_tensor.size()
+
+ xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
+ yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
+
+ xx_channel = xx_channel / (x_dim - 1)
+ yy_channel = yy_channel / (y_dim - 1)
+
+ xx_channel = xx_channel * 2 - 1
+ yy_channel = yy_channel * 2 - 1
+
+ xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+ yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+
+ if input_tensor.is_cuda:
+ xx_channel = xx_channel.to(input_tensor.device)
+ yy_channel = yy_channel.to(input_tensor.device)
+
+ ret = torch.cat([input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)], dim=1)
+
+ if self.with_r:
+ rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
+ if input_tensor.is_cuda:
+ rr = rr.to(input_tensor.device)
+ ret = torch.cat([ret, rr], dim=1)
+
+ return ret
+
+
+class CoordConv(nn.Module):
+ def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
+ super().__init__()
+ self.addcoords = AddCoords(with_r=with_r)
+ self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)
+
+ def forward(self, x):
+ ret = self.addcoords(x)
+ ret = self.conv(ret)
+ return ret
diff --git a/AdaptiveWingLoss/core/dataloader.py b/AdaptiveWingLoss/core/dataloader.py
new file mode 100755
index 0000000000000000000000000000000000000000..b901540a29e1a72abe2e88c4d052791a9dc4ec36
--- /dev/null
+++ b/AdaptiveWingLoss/core/dataloader.py
@@ -0,0 +1,350 @@
+import copy
+import glob
+import math
+import os
+import random
+import sys
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.io as sio
+import torch
+from imgaug import augmenters as iaa
+from PIL import Image
+from scipy import interpolate
+from skimage import io
+from skimage import transform as ski_transform
+from skimage.color import rgb2gray
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision import utils
+from torchvision.transforms import Compose
+from torchvision.transforms import Lambda
+from torchvision.transforms.functional import adjust_brightness
+from torchvision.transforms.functional import adjust_contrast
+from torchvision.transforms.functional import adjust_hue
+from torchvision.transforms.functional import adjust_saturation
+
+from utils.utils import cv_crop
+from utils.utils import cv_rotate
+from utils.utils import draw_gaussian
+from utils.utils import fig2data
+from utils.utils import generate_weight_map
+from utils.utils import power_transform
+from utils.utils import shuffle_lr
+from utils.utils import transform
+
+
+class AddBoundary(object):
+ def __init__(self, num_landmarks=68):
+ self.num_landmarks = num_landmarks
+
+ def __call__(self, sample):
+ landmarks_64 = np.floor(sample["landmarks"] / 4.0)
+ if self.num_landmarks == 68:
+ boundaries = {}
+ boundaries["cheek"] = landmarks_64[0:17]
+ boundaries["left_eyebrow"] = landmarks_64[17:22]
+ boundaries["right_eyebrow"] = landmarks_64[22:27]
+ boundaries["uper_left_eyelid"] = landmarks_64[36:40]
+ boundaries["lower_left_eyelid"] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]])
+ boundaries["upper_right_eyelid"] = landmarks_64[42:46]
+ boundaries["lower_right_eyelid"] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]])
+ boundaries["noise"] = landmarks_64[27:31]
+ boundaries["noise_bot"] = landmarks_64[31:36]
+ boundaries["upper_outer_lip"] = landmarks_64[48:55]
+ boundaries["upper_inner_lip"] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]])
+ boundaries["lower_outer_lip"] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]])
+ boundaries["lower_inner_lip"] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
+ elif self.num_landmarks == 98:
+ boundaries = {}
+ boundaries["cheek"] = landmarks_64[0:33]
+ boundaries["left_eyebrow"] = landmarks_64[33:38]
+ boundaries["right_eyebrow"] = landmarks_64[42:47]
+ boundaries["uper_left_eyelid"] = landmarks_64[60:65]
+ boundaries["lower_left_eyelid"] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
+ boundaries["upper_right_eyelid"] = landmarks_64[68:73]
+ boundaries["lower_right_eyelid"] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]])
+ boundaries["noise"] = landmarks_64[51:55]
+ boundaries["noise_bot"] = landmarks_64[55:60]
+ boundaries["upper_outer_lip"] = landmarks_64[76:83]
+ boundaries["upper_inner_lip"] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]])
+ boundaries["lower_outer_lip"] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]])
+ boundaries["lower_inner_lip"] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]])
+ elif self.num_landmarks == 19:
+ boundaries = {}
+ boundaries["left_eyebrow"] = landmarks_64[0:3]
+ boundaries["right_eyebrow"] = landmarks_64[3:5]
+ boundaries["left_eye"] = landmarks_64[6:9]
+ boundaries["right_eye"] = landmarks_64[9:12]
+ boundaries["noise"] = landmarks_64[12:15]
+
+ elif self.num_landmarks == 29:
+ boundaries = {}
+ boundaries["upper_left_eyebrow"] = np.stack([landmarks_64[0], landmarks_64[4], landmarks_64[2]], axis=0)
+ boundaries["lower_left_eyebrow"] = np.stack([landmarks_64[0], landmarks_64[5], landmarks_64[2]], axis=0)
+ boundaries["upper_right_eyebrow"] = np.stack([landmarks_64[1], landmarks_64[6], landmarks_64[3]], axis=0)
+ boundaries["lower_right_eyebrow"] = np.stack([landmarks_64[1], landmarks_64[7], landmarks_64[3]], axis=0)
+ boundaries["upper_left_eye"] = np.stack([landmarks_64[8], landmarks_64[12], landmarks_64[10]], axis=0)
+ boundaries["lower_left_eye"] = np.stack([landmarks_64[8], landmarks_64[13], landmarks_64[10]], axis=0)
+ boundaries["upper_right_eye"] = np.stack([landmarks_64[9], landmarks_64[14], landmarks_64[11]], axis=0)
+ boundaries["lower_right_eye"] = np.stack([landmarks_64[9], landmarks_64[15], landmarks_64[11]], axis=0)
+ boundaries["noise"] = np.stack([landmarks_64[18], landmarks_64[21], landmarks_64[19]], axis=0)
+ boundaries["outer_upper_lip"] = np.stack([landmarks_64[22], landmarks_64[24], landmarks_64[23]], axis=0)
+ boundaries["inner_upper_lip"] = np.stack([landmarks_64[22], landmarks_64[25], landmarks_64[23]], axis=0)
+ boundaries["outer_lower_lip"] = np.stack([landmarks_64[22], landmarks_64[26], landmarks_64[23]], axis=0)
+ boundaries["inner_lower_lip"] = np.stack([landmarks_64[22], landmarks_64[27], landmarks_64[23]], axis=0)
+ functions = {}
+
+ for key, points in boundaries.items():
+ temp = points[0]
+ new_points = points[0:1, :]
+ for point in points[1:]:
+ if point[0] == temp[0] and point[1] == temp[1]:
+ continue
+ else:
+ new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0)
+ temp = point
+ points = new_points
+ if points.shape[0] == 1:
+ points = np.concatenate((points, points + 0.001), axis=0)
+ k = min(4, points.shape[0])
+ functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k - 1, s=0)
+
+ boundary_map = np.zeros((64, 64))
+
+ fig = plt.figure(figsize=[64 / 96.0, 64 / 96.0], dpi=96)
+
+ ax = fig.add_axes([0, 0, 1, 1])
+
+ ax.axis("off")
+
+ ax.imshow(boundary_map, interpolation="nearest", cmap="gray")
+ # ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w')
+
+ for key in functions.keys():
+ xnew = np.arange(0, 1, 0.01)
+ out = interpolate.splev(xnew, functions[key][0], der=0)
+ plt.plot(out[0], out[1], ",", linewidth=1, color="w")
+
+ img = fig2data(fig)
+
+ plt.close()
+
+ sigma = 1
+ temp = 255 - img[:, :, 1]
+ temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
+ temp = temp.astype(np.float32)
+ temp = np.where(temp < 3 * sigma, np.exp(-(temp * temp) / (2 * sigma * sigma)), 0)
+
+ fig = plt.figure(figsize=[64 / 96.0, 64 / 96.0], dpi=96)
+
+ ax = fig.add_axes([0, 0, 1, 1])
+
+ ax.axis("off")
+ ax.imshow(temp, cmap="gray")
+ plt.close()
+
+ boundary_map = fig2data(fig)
+
+ sample["boundary"] = boundary_map[:, :, 0]
+
+ return sample
+
+
+class AddWeightMap(object):
+ def __call__(self, sample):
+ heatmap = sample["heatmap"]
+ boundary = sample["boundary"]
+ heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0)
+ weight_map = np.zeros_like(heatmap)
+ for i in range(heatmap.shape[0]):
+ weight_map[i] = generate_weight_map(weight_map[i], heatmap[i])
+ sample["weight_map"] = weight_map
+ return sample
+
+
+class ToTensor(object):
+ """Convert ndarrays in sample to Tensors."""
+
+ def __call__(self, sample):
+ image, heatmap, landmarks, boundary, weight_map = (
+ sample["image"],
+ sample["heatmap"],
+ sample["landmarks"],
+ sample["boundary"],
+ sample["weight_map"],
+ )
+
+ # swap color axis because
+ # numpy image: H x W x C
+ # torch image: C X H X W
+ if len(image.shape) == 2:
+ image = np.expand_dims(image, axis=2)
+ image_small = np.expand_dims(image_small, axis=2)
+ image = image.transpose((2, 0, 1))
+ boundary = np.expand_dims(boundary, axis=2)
+ boundary = boundary.transpose((2, 0, 1))
+ return {
+ "image": torch.from_numpy(image).float().div(255.0),
+ "heatmap": torch.from_numpy(heatmap).float(),
+ "landmarks": torch.from_numpy(landmarks).float(),
+ "boundary": torch.from_numpy(boundary).float().div(255.0),
+ "weight_map": torch.from_numpy(weight_map).float(),
+ }
+
+
+class FaceLandmarksDataset(Dataset):
+ """Face Landmarks dataset."""
+
+ def __init__(
+ self,
+ img_dir,
+ landmarks_dir,
+ num_landmarks=68,
+ gray_scale=False,
+ detect_face=False,
+ enhance=False,
+ center_shift=0,
+ transform=None,
+ ):
+ """
+ Args:
+ landmark_dir (string): Path to the mat file with landmarks saved.
+ img_dir (string): Directory with all the images.
+ transform (callable, optional): Optional transform to be applied
+ on a sample.
+ """
+ self.img_dir = img_dir
+ self.landmarks_dir = landmarks_dir
+ self.num_lanmdkars = num_landmarks
+ self.transform = transform
+ self.img_names = glob.glob(self.img_dir + "*.jpg") + glob.glob(self.img_dir + "*.png")
+ self.gray_scale = gray_scale
+ self.detect_face = detect_face
+ self.enhance = enhance
+ self.center_shift = center_shift
+ if self.detect_face:
+ self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7])
+
+ def __len__(self):
+ return len(self.img_names)
+
+ def __getitem__(self, idx):
+ img_name = self.img_names[idx]
+ pil_image = Image.open(img_name)
+ if pil_image.mode != "RGB":
+ # if input is grayscale image, convert it to 3 channel image
+ if self.enhance:
+ pil_image = power_transform(pil_image, 0.5)
+ temp_image = Image.new("RGB", pil_image.size)
+ temp_image.paste(pil_image)
+ pil_image = temp_image
+ image = np.array(pil_image)
+ if self.gray_scale:
+ image = rgb2gray(image)
+ image = np.expand_dims(image, axis=2)
+ image = np.concatenate((image, image, image), axis=2)
+ image = image * 255.0
+ image = image.astype(np.uint8)
+ if not self.detect_face:
+ center = [450 // 2, 450 // 2 + 0]
+ if self.center_shift != 0:
+ center[0] += int(np.random.uniform(-self.center_shift, self.center_shift))
+ center[1] += int(np.random.uniform(-self.center_shift, self.center_shift))
+ scale = 1.8
+ else:
+ detected_faces = self.face_detector.detect_image(image)
+ if len(detected_faces) > 0:
+ box = detected_faces[0]
+ left, top, right, bottom, _ = box
+ center = [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]
+ center[1] = center[1] - (bottom - top) * 0.12
+ scale = (right - left + bottom - top) / 195.0
+ else:
+ center = [450 // 2, 450 // 2 + 0]
+ scale = 1.8
+ if self.center_shift != 0:
+ shift = self.center * self.center_shift / 450
+ center[0] += int(np.random.uniform(-shift, shift))
+ center[1] += int(np.random.uniform(-shift, shift))
+ base_name = os.path.basename(img_name)
+ landmarks_base_name = base_name[:-4] + "_pts.mat"
+ landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name)
+ if os.path.isfile(landmarks_name):
+ mat_data = sio.loadmat(landmarks_name)
+ landmarks = mat_data["pts_2d"]
+ elif os.path.isfile(landmarks_name[:-8] + ".pts.npy"):
+ landmarks = np.load(landmarks_name[:-8] + ".pts.npy")
+ else:
+ landmarks = []
+ heatmap = []
+
+ if landmarks != []:
+ new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, self.center_shift)
+ tries = 0
+ while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15):
+ center = [450 // 2, 450 // 2 + 0]
+ scale += 0.05
+ center[0] += int(np.random.uniform(-self.center_shift, self.center_shift))
+ center[1] += int(np.random.uniform(-self.center_shift, self.center_shift))
+
+ new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, self.center_shift)
+ tries += 1
+ if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5:
+ center = [450 // 2, 450 // 2 + 0]
+ scale = 2.25
+ new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, 100)
+ assert np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256, "Landmarks out of boundary!"
+ image = new_image
+ landmarks = new_landmarks
+ heatmap = np.zeros((self.num_lanmdkars, 64, 64))
+ for i in range(self.num_lanmdkars):
+ if landmarks[i][0] > 0:
+ heatmap[i] = draw_gaussian(heatmap[i], landmarks[i] / 4.0 + 1, 1)
+ sample = {"image": image, "heatmap": heatmap, "landmarks": landmarks}
+ if self.transform:
+ sample = self.transform(sample)
+
+ return sample
+
+
+def get_dataset(
+ val_img_dir,
+ val_landmarks_dir,
+ batch_size,
+ num_landmarks=68,
+ rotation=0,
+ scale=0,
+ center_shift=0,
+ random_flip=False,
+ brightness=0,
+ contrast=0,
+ saturation=0,
+ blur=False,
+ noise=False,
+ jpeg_effect=False,
+ random_occlusion=False,
+ gray_scale=False,
+ detect_face=False,
+ enhance=False,
+):
+ val_transforms = transforms.Compose([AddBoundary(num_landmarks), AddWeightMap(), ToTensor()])
+
+ val_dataset = FaceLandmarksDataset(
+ val_img_dir,
+ val_landmarks_dir,
+ num_landmarks=num_landmarks,
+ gray_scale=gray_scale,
+ detect_face=detect_face,
+ enhance=enhance,
+ transform=val_transforms,
+ )
+
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
+ data_loaders = {"val": val_dataloader}
+ dataset_sizes = {}
+ dataset_sizes["val"] = len(val_dataset)
+ return data_loaders, dataset_sizes
diff --git a/AdaptiveWingLoss/core/evaler.py b/AdaptiveWingLoss/core/evaler.py
new file mode 100755
index 0000000000000000000000000000000000000000..9a1a3c26560dc6a34067df513f7cc85798fa3b25
--- /dev/null
+++ b/AdaptiveWingLoss/core/evaler.py
@@ -0,0 +1,125 @@
+import matplotlib
+
+matplotlib.use("Agg")
+import math
+import torch
+import copy
+import time
+from torch.autograd import Variable
+import shutil
+from skimage import io
+import numpy as np
+from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
+from PIL import Image, ImageDraw
+import os
+import sys
+import cv2
+import matplotlib.pyplot as plt
+
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def eval_model(
+ model, dataloaders, dataset_sizes, writer, use_gpu=True, epoches=5, dataset="val", save_path="./", num_landmarks=68
+):
+ global_nme = 0
+ model.eval()
+ for epoch in range(epoches):
+ running_loss = 0
+ step = 0
+ total_nme = 0
+ total_count = 0
+ fail_count = 0
+ nmes = []
+ # running_corrects = 0
+
+ # Iterate over data.
+ with torch.no_grad():
+ for data in dataloaders[dataset]:
+ total_runtime = 0
+ run_count = 0
+ step_start = time.time()
+ step += 1
+ # get the inputs
+ inputs = data["image"].type(torch.FloatTensor)
+ labels_heatmap = data["heatmap"].type(torch.FloatTensor)
+ labels_boundary = data["boundary"].type(torch.FloatTensor)
+ landmarks = data["landmarks"].type(torch.FloatTensor)
+ loss_weight_map = data["weight_map"].type(torch.FloatTensor)
+ # wrap them in Variable
+ if use_gpu:
+ inputs = inputs.to(device)
+ labels_heatmap = labels_heatmap.to(device)
+ labels_boundary = labels_boundary.to(device)
+ loss_weight_map = loss_weight_map.to(device)
+ else:
+ inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap)
+ labels_boundary = Variable(labels_boundary)
+ labels = torch.cat((labels_heatmap, labels_boundary), 1)
+ single_start = time.time()
+ outputs, boundary_channels = model(inputs)
+ single_end = time.time()
+ total_runtime += time.time() - single_start
+ run_count += 1
+ step_end = time.time()
+ for i in range(inputs.shape[0]):
+ img = inputs[i]
+ img = img.cpu().numpy()
+ img = img.transpose((1, 2, 0)) * 255.0
+ img = img.astype(np.uint8)
+ img = Image.fromarray(img)
+ # pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :]
+ pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu()
+ pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
+ pred_landmarks = pred_landmarks.squeeze().numpy()
+
+ gt_landmarks = data["landmarks"][i].numpy()
+ if num_landmarks == 68:
+ left_eye = np.average(gt_landmarks[36:42], axis=0)
+ right_eye = np.average(gt_landmarks[42:48], axis=0)
+ norm_factor = np.linalg.norm(left_eye - right_eye)
+ # norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45])
+
+ elif num_landmarks == 98:
+ norm_factor = np.linalg.norm(gt_landmarks[60] - gt_landmarks[72])
+ elif num_landmarks == 19:
+ left, top = gt_landmarks[-2, :]
+ right, bottom = gt_landmarks[-1, :]
+ norm_factor = math.sqrt(abs(right - left) * abs(top - bottom))
+ gt_landmarks = gt_landmarks[:-2, :]
+ elif num_landmarks == 29:
+ # norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9])
+ norm_factor = np.linalg.norm(gt_landmarks[16] - gt_landmarks[17])
+ single_nme = (
+ np.sum(np.linalg.norm(pred_landmarks * 4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]
+ ) / norm_factor
+
+ nmes.append(single_nme)
+ total_count += 1
+ if single_nme > 0.1:
+ fail_count += 1
+ if step % 10 == 0:
+ print(
+ "Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}".format(
+ step, step_end - step_start, torch.mean(labels), torch.mean(outputs[0])
+ )
+ )
+ # gt_landmarks = landmarks.numpy()
+ # pred_heatmap = outputs[-1].to('cpu').numpy()
+ gt_landmarks = landmarks
+ batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
+ # batch_nme = 0
+ total_nme += batch_nme
+ epoch_nme = total_nme / dataset_sizes["val"]
+ global_nme += epoch_nme
+ nme_save_path = os.path.join(save_path, "nme_log.npy")
+ np.save(nme_save_path, np.array(nmes))
+ print(
+ "NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}".format(
+ epoch_nme, fail_count / total_count, total_count, fail_count
+ )
+ )
+ print("Evaluation done! Average NME: {:.6f}".format(global_nme / epoches))
+ print("Everage runtime for a single batch: {:.6f}".format(total_runtime / run_count))
+ return model
diff --git a/AdaptiveWingLoss/core/models.py b/AdaptiveWingLoss/core/models.py
new file mode 100755
index 0000000000000000000000000000000000000000..f2c3d5c0616db1869a8ff9c9e74dffaf4c1c8828
--- /dev/null
+++ b/AdaptiveWingLoss/core/models.py
@@ -0,0 +1,239 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from AdaptiveWingLoss.core.coord_conv import CoordConvTh
+
+
+def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ # self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ # self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ # out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ # out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_planes, out_planes):
+ super(ConvBlock, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1)
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1)
+
+ if in_planes != out_planes:
+ self.downsample = nn.Sequential(
+ nn.BatchNorm2d(in_planes),
+ nn.ReLU(True),
+ nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+
+ out1 = self.bn1(x)
+ out1 = F.relu(out1, True)
+ out1 = self.conv1(out1)
+
+ out2 = self.bn2(out1)
+ out2 = F.relu(out2, True)
+ out2 = self.conv2(out2)
+
+ out3 = self.bn3(out2)
+ out3 = F.relu(out3, True)
+ out3 = self.conv3(out3)
+
+ out3 = torch.cat((out1, out2, out3), 1)
+
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+
+ out3 += residual
+
+ return out3
+
+
+class HourGlass(nn.Module):
+ def __init__(self, num_modules, depth, num_features, first_one=False):
+ super(HourGlass, self).__init__()
+ self.num_modules = num_modules
+ self.depth = depth
+ self.features = num_features
+ self.coordconv = CoordConvTh(
+ x_dim=64,
+ y_dim=64,
+ with_r=True,
+ with_boundary=True,
+ in_channels=256,
+ first_one=first_one,
+ out_channels=256,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self._generate_network(self.depth)
+
+ def _generate_network(self, level):
+ self.add_module("b1_" + str(level), ConvBlock(256, 256))
+
+ self.add_module("b2_" + str(level), ConvBlock(256, 256))
+
+ if level > 1:
+ self._generate_network(level - 1)
+ else:
+ self.add_module("b2_plus_" + str(level), ConvBlock(256, 256))
+
+ self.add_module("b3_" + str(level), ConvBlock(256, 256))
+
+ def _forward(self, level, inp):
+ # Upper branch
+ up1 = inp
+ up1 = self._modules["b1_" + str(level)](up1)
+
+ # Lower branch
+ low1 = F.avg_pool2d(inp, 2, stride=2)
+ low1 = self._modules["b2_" + str(level)](low1)
+
+ if level > 1:
+ low2 = self._forward(level - 1, low1)
+ else:
+ low2 = low1
+ low2 = self._modules["b2_plus_" + str(level)](low2)
+
+ low3 = low2
+ low3 = self._modules["b3_" + str(level)](low3)
+
+ up2 = F.upsample(low3, scale_factor=2, mode="nearest")
+
+ return up1 + up2
+
+ def forward(self, x, heatmap):
+ x, last_channel = self.coordconv(x, heatmap)
+ return self._forward(self.depth, x), last_channel
+
+
+class FAN(nn.Module):
+ def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68):
+ super(FAN, self).__init__()
+ self.num_modules = num_modules
+ self.gray_scale = gray_scale
+ self.end_relu = end_relu
+ self.num_landmarks = num_landmarks
+
+ # Base part
+ if self.gray_scale:
+ self.conv1 = CoordConvTh(
+ x_dim=256,
+ y_dim=256,
+ with_r=True,
+ with_boundary=False,
+ in_channels=3,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ )
+ else:
+ self.conv1 = CoordConvTh(
+ x_dim=256,
+ y_dim=256,
+ with_r=True,
+ with_boundary=False,
+ in_channels=3,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ )
+ self.bn1 = nn.BatchNorm2d(64)
+ self.conv2 = ConvBlock(64, 128)
+ self.conv3 = ConvBlock(128, 128)
+ self.conv4 = ConvBlock(128, 256)
+
+ # Stacking part
+ for hg_module in range(self.num_modules):
+ if hg_module == 0:
+ first_one = True
+ else:
+ first_one = False
+ self.add_module("m" + str(hg_module), HourGlass(1, 4, 256, first_one))
+ self.add_module("top_m_" + str(hg_module), ConvBlock(256, 256))
+ self.add_module("conv_last" + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module("bn_end" + str(hg_module), nn.BatchNorm2d(256))
+ self.add_module("l" + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0))
+
+ if hg_module < self.num_modules - 1:
+ self.add_module("bl" + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module(
+ "al" + str(hg_module), nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)
+ )
+
+ def forward(self, x):
+ x, _ = self.conv1(x)
+ x = F.relu(self.bn1(x), True)
+ # x = F.relu(self.bn1(self.conv1(x)), True)
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ previous = x
+
+ outputs = []
+ boundary_channels = []
+ tmp_out = None
+ for i in range(self.num_modules):
+ hg, boundary_channel = self._modules["m" + str(i)](previous, tmp_out)
+
+ ll = hg
+ ll = self._modules["top_m_" + str(i)](ll)
+
+ ll = F.relu(self._modules["bn_end" + str(i)](self._modules["conv_last" + str(i)](ll)), True)
+
+ # Predict heatmaps
+ tmp_out = self._modules["l" + str(i)](ll)
+ if self.end_relu:
+ tmp_out = F.relu(tmp_out) # HACK: Added relu
+ outputs.append(tmp_out)
+ boundary_channels.append(boundary_channel)
+
+ if i < self.num_modules - 1:
+ ll = self._modules["bl" + str(i)](ll)
+ tmp_out_ = self._modules["al" + str(i)](tmp_out)
+ previous = previous + ll + tmp_out_
+
+ return outputs, boundary_channels
diff --git a/AdaptiveWingLoss/utils/__init__.py b/AdaptiveWingLoss/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AdaptiveWingLoss/utils/utils.py b/AdaptiveWingLoss/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..9daa0be2b71c5e7d92679f38692a891ac7e3cc0a
--- /dev/null
+++ b/AdaptiveWingLoss/utils/utils.py
@@ -0,0 +1,437 @@
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+import sys
+
+import cv2
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image
+from scipy import ndimage
+from skimage import io
+from skimage import transform as ski_transform
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision import utils
+
+
+def _gaussian(
+ size=3,
+ sigma=0.25,
+ amplitude=1,
+ normalize=False,
+ width=None,
+ height=None,
+ sigma_horz=None,
+ sigma_vert=None,
+ mean_horz=0.5,
+ mean_vert=0.5,
+):
+ # handle some defaults
+ if width is None:
+ width = size
+ if height is None:
+ height = size
+ if sigma_horz is None:
+ sigma_horz = sigma
+ if sigma_vert is None:
+ sigma_vert = sigma
+ center_x = mean_horz * width + 0.5
+ center_y = mean_vert * height + 0.5
+ gauss = np.empty((height, width), dtype=np.float32)
+ # generate kernel
+ for i in range(height):
+ for j in range(width):
+ gauss[i][j] = amplitude * math.exp(
+ -(
+ math.pow((j + 1 - center_x) / (sigma_horz * width), 2) / 2.0
+ + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0
+ )
+ )
+ if normalize:
+ gauss = gauss / np.sum(gauss)
+ return gauss
+
+
+def draw_gaussian(image, point, sigma):
+ # Check if the gaussian is inside
+ ul = [np.floor(np.floor(point[0]) - 3 * sigma), np.floor(np.floor(point[1]) - 3 * sigma)]
+ br = [np.floor(np.floor(point[0]) + 3 * sigma), np.floor(np.floor(point[1]) + 3 * sigma)]
+ if ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1:
+ return image
+ size = 6 * sigma + 1
+ g = _gaussian(size)
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
+ assert g_x[0] > 0 and g_y[1] > 0
+ correct = False
+ while not correct:
+ try:
+ image[img_y[0] - 1 : img_y[1], img_x[0] - 1 : img_x[1]] = (
+ image[img_y[0] - 1 : img_y[1], img_x[0] - 1 : img_x[1]] + g[g_y[0] - 1 : g_y[1], g_x[0] - 1 : g_x[1]]
+ )
+ correct = True
+ except:
+ print(
+ "img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}".format(
+ img_x, img_y, g_x, g_y, point, g.shape, ul, br
+ )
+ )
+ ul = [np.floor(np.floor(point[0]) - 3 * sigma), np.floor(np.floor(point[1]) - 3 * sigma)]
+ br = [np.floor(np.floor(point[0]) + 3 * sigma), np.floor(np.floor(point[1]) + 3 * sigma)]
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
+ pass
+ image[image > 1] = 1
+ return image
+
+
+def transform(point, center, scale, resolution, rotation=0, invert=False):
+ _pt = np.ones(3)
+ _pt[0] = point[0]
+ _pt[1] = point[1]
+
+ h = 200.0 * scale
+ t = np.eye(3)
+ t[0, 0] = resolution / h
+ t[1, 1] = resolution / h
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
+
+ if rotation != 0:
+ rotation = -rotation
+ r = np.eye(3)
+ ang = rotation * math.pi / 180.0
+ s = math.sin(ang)
+ c = math.cos(ang)
+ r[0][0] = c
+ r[0][1] = -s
+ r[1][0] = s
+ r[1][1] = c
+
+ t_ = np.eye(3)
+ t_[0][2] = -resolution / 2.0
+ t_[1][2] = -resolution / 2.0
+ t_inv = torch.eye(3)
+ t_inv[0][2] = resolution / 2.0
+ t_inv[1][2] = resolution / 2.0
+ t = reduce(np.matmul, [t_inv, r, t_, t])
+
+ if invert:
+ t = np.linalg.inv(t)
+ new_point = (np.matmul(t, _pt))[0:2]
+
+ return new_point.astype(int)
+
+
+def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0):
+ new_image = cv2.copyMakeBorder(
+ image, center_shift, center_shift, center_shift, center_shift, cv2.BORDER_CONSTANT, value=[0, 0, 0]
+ )
+ new_landmarks = landmarks.copy()
+ if center_shift != 0:
+ center[0] += center_shift
+ center[1] += center_shift
+ new_landmarks = new_landmarks + center_shift
+ length = 200 * scale
+ top = int(center[1] - length // 2)
+ bottom = int(center[1] + length // 2)
+ left = int(center[0] - length // 2)
+ right = int(center[0] + length // 2)
+ y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
+ x_pad = abs(min(left, new_image.shape[1] - right, 0))
+ top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
+ new_image = cv2.copyMakeBorder(new_image, y_pad, y_pad, x_pad, x_pad, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+ new_image = new_image[top:bottom, left:right]
+ new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR)
+ new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length
+ new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length
+ return new_image, new_landmarks
+
+
+def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256):
+ img_mat = cv2.getRotationMatrix2D((resolution // 2, resolution // 2), rot, scale)
+ ones = np.ones(shape=(landmarks.shape[0], 1))
+ stacked_landmarks = np.hstack([landmarks, ones])
+ new_landmarks = img_mat.dot(stacked_landmarks.T).T
+ if np.max(new_landmarks) > 255 or np.min(new_landmarks) < 0:
+ return image, landmarks, heatmap
+ else:
+ new_image = cv2.warpAffine(image, img_mat, (resolution, resolution))
+ if heatmap is not None:
+ new_heatmap = np.zeros((heatmap.shape[0], 64, 64))
+ for i in range(heatmap.shape[0]):
+ if new_landmarks[i][0] > 0:
+ new_heatmap[i] = draw_gaussian(new_heatmap[i], new_landmarks[i] / 4.0 + 1, 1)
+ return new_image, new_landmarks, new_heatmap
+
+
+def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap):
+ """Show image with pred_landmarks"""
+ pred_landmarks = []
+ pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
+ pred_landmarks = pred_landmarks.squeeze() * 4
+
+ # pred_landmarks2 = get_preds_fromhm2(heatmap)
+ heatmap = np.max(gt_heatmap, axis=0)
+ heatmap = heatmap / np.max(heatmap)
+ # image = ski_transform.resize(image, (64, 64))*255
+ image = image.astype(np.uint8)
+ heatmap = np.max(gt_heatmap, axis=0)
+ heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
+ heatmap *= 255
+ heatmap = heatmap.astype(np.uint8)
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+ plt.imshow(image)
+ plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker=".", c="g")
+ plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker=".", c="r")
+ plt.pause(0.001) # pause a bit so that plots are updated
+
+
+def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68):
+ """
+ Calculate total NME for a batch of data
+
+ Args:
+ pred_heatmaps: torch tensor of size [batch, points, height, width]
+ gt_landmarks: torch tesnsor of size [batch, points, x, y]
+
+ Returns:
+ nme: sum of nme for this batch
+ """
+ nme = 0
+ pred_landmarks, _ = get_preds_fromhm(pred_heatmaps)
+ pred_landmarks = pred_landmarks.numpy()
+ gt_landmarks = gt_landmarks.numpy()
+ for i in range(pred_landmarks.shape[0]):
+ pred_landmark = pred_landmarks[i] * 4.0
+ gt_landmark = gt_landmarks[i]
+
+ if num_landmarks == 68:
+ left_eye = np.average(gt_landmark[36:42], axis=0)
+ right_eye = np.average(gt_landmark[42:48], axis=0)
+ norm_factor = np.linalg.norm(left_eye - right_eye)
+ # norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45])
+ elif num_landmarks == 98:
+ norm_factor = np.linalg.norm(gt_landmark[60] - gt_landmark[72])
+ elif num_landmarks == 19:
+ left, top = gt_landmark[-2, :]
+ right, bottom = gt_landmark[-1, :]
+ norm_factor = math.sqrt(abs(right - left) * abs(top - bottom))
+ gt_landmark = gt_landmark[:-2, :]
+ elif num_landmarks == 29:
+ # norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9])
+ norm_factor = np.linalg.norm(gt_landmark[16] - gt_landmark[17])
+ nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
+ return nme
+
+
+def fan_NME_hm(pred_heatmaps, gt_heatmaps, num_landmarks=68):
+ """
+ Calculate total NME for a batch of data
+
+ Args:
+ pred_heatmaps: torch tensor of size [batch, points, height, width]
+ gt_landmarks: torch tesnsor of size [batch, points, x, y]
+
+ Returns:
+ nme: sum of nme for this batch
+ """
+ nme = 0
+ pred_landmarks, _ = get_index_fromhm(pred_heatmaps)
+ pred_landmarks = pred_landmarks.numpy()
+ gt_landmarks = gt_landmarks.numpy()
+ for i in range(pred_landmarks.shape[0]):
+ pred_landmark = pred_landmarks[i] * 4.0
+ gt_landmark = gt_landmarks[i]
+ if num_landmarks == 68:
+ left_eye = np.average(gt_landmark[36:42], axis=0)
+ right_eye = np.average(gt_landmark[42:48], axis=0)
+ norm_factor = np.linalg.norm(left_eye - right_eye)
+ else:
+ norm_factor = np.linalg.norm(gt_landmark[60] - gt_landmark[72])
+ nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
+ return nme
+
+
+def power_transform(img, power):
+ img = np.array(img)
+ img_new = np.power((img / 255.0), power) * 255.0
+ img_new = img_new.astype(np.uint8)
+ img_new = Image.fromarray(img_new)
+ return img_new
+
+
+def get_preds_fromhm(hm, center=None, scale=None, rot=None):
+ max, idx = torch.max(hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+ idx += 1
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
+ for i in range(preds.size(0)):
+ for j in range(preds.size(1)):
+ hm_ = hm[i, j, :]
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+ diff = torch.FloatTensor([hm_[pY, pX + 1] - hm_[pY, pX - 1], hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+ preds[i, j].add_(diff.sign_().mul_(0.25))
+
+ preds.add_(-0.5)
+
+ preds_orig = torch.zeros(preds.size())
+ if center is not None and scale is not None:
+ for i in range(hm.size(0)):
+ for j in range(hm.size(1)):
+ preds_orig[i, j] = transform(preds[i, j], center, scale, hm.size(2), rot, True)
+
+ return preds, preds_orig
+
+
+def get_index_fromhm(hm):
+ max, idx = torch.max(hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+ preds[..., 0].remainder_(hm.size(3))
+ preds[..., 1].div_(hm.size(2)).floor_()
+
+ for i in range(preds.size(0)):
+ for j in range(preds.size(1)):
+ hm_ = hm[i, j, :]
+ pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1])
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+ diff = torch.FloatTensor([hm_[pY, pX + 1] - hm_[pY, pX - 1], hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+ preds[i, j].add_(diff.sign_().mul_(0.25))
+
+ return preds
+
+
+def shuffle_lr(parts, num_landmarks=68, pairs=None):
+ if num_landmarks == 68:
+ if pairs is None:
+ pairs = [
+ [0, 16],
+ [1, 15],
+ [2, 14],
+ [3, 13],
+ [4, 12],
+ [5, 11],
+ [6, 10],
+ [7, 9],
+ [17, 26],
+ [18, 25],
+ [19, 24],
+ [20, 23],
+ [21, 22],
+ [36, 45],
+ [37, 44],
+ [38, 43],
+ [39, 42],
+ [41, 46],
+ [40, 47],
+ [31, 35],
+ [32, 34],
+ [50, 52],
+ [49, 53],
+ [48, 54],
+ [61, 63],
+ [60, 64],
+ [67, 65],
+ [59, 55],
+ [58, 56],
+ ]
+ elif num_landmarks == 98:
+ if pairs is None:
+ pairs = [
+ [0, 32],
+ [1, 31],
+ [2, 30],
+ [3, 29],
+ [4, 28],
+ [5, 27],
+ [6, 26],
+ [7, 25],
+ [8, 24],
+ [9, 23],
+ [10, 22],
+ [11, 21],
+ [12, 20],
+ [13, 19],
+ [14, 18],
+ [15, 17],
+ [33, 46],
+ [34, 45],
+ [35, 44],
+ [36, 43],
+ [37, 42],
+ [38, 50],
+ [39, 49],
+ [40, 48],
+ [41, 47],
+ [60, 72],
+ [61, 71],
+ [62, 70],
+ [63, 69],
+ [64, 68],
+ [65, 75],
+ [66, 74],
+ [67, 73],
+ [96, 97],
+ [55, 59],
+ [56, 58],
+ [76, 82],
+ [77, 81],
+ [78, 80],
+ [88, 92],
+ [89, 91],
+ [95, 93],
+ [87, 83],
+ [86, 84],
+ ]
+ elif num_landmarks == 19:
+ if pairs is None:
+ pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]]
+ elif num_landmarks == 29:
+ if pairs is None:
+ pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]]
+ for matched_p in pairs:
+ idx1, idx2 = matched_p[0], matched_p[1]
+ tmp = np.copy(parts[idx1])
+ np.copyto(parts[idx1], parts[idx2])
+ np.copyto(parts[idx2], tmp)
+ return parts
+
+
+def generate_weight_map(weight_map, heatmap):
+
+ k_size = 3
+ dilate = ndimage.grey_dilation(heatmap, size=(k_size, k_size))
+ weight_map[np.where(dilate > 0.2)] = 1
+ return weight_map
+
+
+def fig2data(fig):
+ """
+ @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
+ @param fig a matplotlib figure
+ @return a numpy 3D array of RGBA values
+ """
+ # draw the renderer
+ fig.canvas.draw()
+
+ # Get the RGB buffer from the figure
+ w, h = fig.canvas.get_width_height()
+ buf = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ buf.shape = (w, h, 3)
+
+ # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
+ buf = np.roll(buf, 3, axis=2)
+ return buf
diff --git a/Deep3DFaceRecon_pytorch/LICENSE b/Deep3DFaceRecon_pytorch/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7e2b184589f4acbeb98eecd453c28b41ff551471
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Sicheng Xu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Deep3DFaceRecon_pytorch/README.md b/Deep3DFaceRecon_pytorch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a1b306c8cdec09533ed4cd27ca0799b7f9ab53f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/README.md
@@ -0,0 +1,256 @@
+## Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set —— PyTorch implementation ##
+
+
+
+
+
+This is an unofficial official pytorch implementation of the following paper:
+
+Y. Deng, J. Yang, S. Xu, D. Chen, Y. Jia, and X. Tong, [Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set](https://arxiv.org/abs/1903.08527), IEEE Computer Vision and Pattern Recognition Workshop (CVPRW) on Analysis and Modeling of Faces and Gestures (AMFG), 2019. (**_Best Paper Award!_**)
+
+The method enforces a hybrid-level weakly-supervised training for CNN-based 3D face reconstruction. It is fast, accurate, and robust to pose and occlussions. It achieves state-of-the-art performance on multiple datasets such as FaceWarehouse, MICC Florence and NoW Challenge.
+
+
+For the original tensorflow implementation, check this [repo](https://github.com/microsoft/Deep3DFaceReconstruction).
+
+This implementation is written by S. Xu.
+
+## Performance
+
+### ● Reconstruction accuracy
+
+The pytorch implementation achieves lower shape reconstruction error (9% improvement) compare to the [original tensorflow implementation](https://github.com/microsoft/Deep3DFaceReconstruction). Quantitative evaluation (average shape errors in mm) on several benchmarks is as follows:
+
+|Method|FaceWareHouse|MICC Florence | NoW Challenge |
+|:----:|:-----------:|:-----------:|:-----------:|
+|Deep3DFace Tensorflow | 1.81±0.50 | 1.67±0.50 | 1.54±1.29 |
+|**Deep3DFace PyTorch** |**1.64±0.50**|**1.53±0.45**| **1.41±1.21** |
+
+The comparison result with state-of-the-art public 3D face reconstruction methods on the NoW face benchmark is as follows:
+|Rank|Method|Median(mm) | Mean(mm) | Std(mm) |
+|:----:|:-----------:|:-----------:|:-----------:|:-----------:|
+| 1. | [DECA\[Feng et al., SIGGRAPH 2021\]](https://github.com/YadiraF/DECA)|1.09|1.38|1.18|
+| **2.** | **Deep3DFace PyTorch**|**1.11**|**1.41**|**1.21**|
+| 3. | [RingNet [Sanyal et al., CVPR 2019]](https://github.com/soubhiksanyal/RingNet) | 1.21 | 1.53 | 1.31 |
+| 4. | [Deep3DFace [Deng et al., CVPRW 2019]](https://github.com/microsoft/Deep3DFaceReconstruction) | 1.23 | 1.54 | 1.29 |
+| 5. | [3DDFA-V2 [Guo et al., ECCV 2020]](https://github.com/cleardusk/3DDFA_V2) | 1.23 | 1.57 | 1.39 |
+| 6. | [MGCNet [Shang et al., ECCV 2020]](https://github.com/jiaxiangshang/MGCNet) | 1.31 | 1.87 | 2.63 |
+| 7. | [PRNet [Feng et al., ECCV 2018]](https://github.com/YadiraF/PRNet) | 1.50 | 1.98 | 1.88 |
+| 8. | [3DMM-CNN [Tran et al., CVPR 2017]](https://github.com/anhttran/3dmm_cnn) | 1.84 | 2.33 | 2.05 |
+
+For more details about the evaluation, check [Now Challenge](https://ringnet.is.tue.mpg.de/challenge.html) website.
+
+**_A recent benchmark [REALY](https://www.realy3dface.com/) indicates that our method still has the SOTA performance! You can check their paper and website for more details._**
+
+### ● Visual quality
+The pytorch implementation achieves better visual consistency with the input images compare to the original tensorflow version.
+
+
+
+
+
+### ● Speed
+The training speed is on par with the original tensorflow implementation. For more information, see [here](https://github.com/sicxu/Deep3DFaceRecon_pytorch#train-the-face-reconstruction-network).
+
+## Major changes
+
+### ● Differentiable renderer
+
+We use [Nvdiffrast](https://nvlabs.github.io/nvdiffrast/) which is a pytorch library that provides high-performance primitive operations for rasterization-based differentiable rendering. The original tensorflow implementation used [tf_mesh_renderer](https://github.com/google/tf_mesh_renderer) instead.
+
+### ● Face recognition model
+
+We use [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch), a state-of-the-art face recognition model, for perceptual loss computation. By contrast, the original tensorflow implementation used [Facenet](https://github.com/davidsandberg/facenet).
+
+### ● Training configuration
+
+Data augmentation is used in the training process which contains random image shifting, scaling, rotation, and flipping. We also enlarge the training batchsize from 5 to 32 to stablize the training process.
+
+### ● Training data
+
+We use an extra high quality face image dataset [FFHQ](https://github.com/NVlabs/ffhq-dataset) to increase the diversity of training data.
+
+## Requirements
+**This implementation is only tested under Ubuntu environment with Nvidia GPUs and CUDA installed.**
+
+## Installation
+1. Clone the repository and set up a conda environment with all dependencies as follows:
+```
+git clone https://github.com/sicxu/Deep3DFaceRecon_pytorch.git
+cd Deep3DFaceRecon_pytorch
+conda env create -f environment.yml
+source activate deep3d_pytorch
+```
+
+2. Install Nvdiffrast library:
+```
+git clone https://github.com/NVlabs/nvdiffrast
+cd nvdiffrast # ./Deep3DFaceRecon_pytorch/nvdiffrast
+pip install .
+```
+
+3. Install Arcface Pytorch:
+```
+cd .. # ./Deep3DFaceRecon_pytorch
+git clone https://github.com/deepinsight/insightface.git
+cp -r ./insightface/recognition/arcface_torch ./models/
+```
+## Inference with a pre-trained model
+
+### Prepare prerequisite models
+1. Our method uses [Basel Face Model 2009 (BFM09)](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-0&id=basel_face_model) to represent 3d faces. Get access to BFM09 using this [link](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). After getting the access, download "01_MorphableModel.mat". In addition, we use an Expression Basis provided by [Guo et al.](https://github.com/Juyong/3DFace). Download the Expression Basis (Exp_Pca.bin) using this [link (google drive)](https://drive.google.com/file/d/1bw5Xf8C12pWmcMhNEu6PtsYVZkVucEN6/view?usp=sharing). Organize all files into the following structure:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── BFM
+ │
+ └─── 01_MorphableModel.mat
+ │
+ └─── Exp_Pca.bin
+ |
+ └─── ...
+```
+2. We provide a model trained on a combination of [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html),
+[LFW](http://vis-www.cs.umass.edu/lfw/), [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm),
+[IJB-A](https://www.nist.gov/programs-projects/face-challenges), [LS3D-W](https://www.adrianbulat.com/face-alignment), and [FFHQ](https://github.com/NVlabs/ffhq-dataset) datasets. Download the pre-trained model using this [link (google drive)](https://drive.google.com/drive/folders/1liaIxn9smpudjjqMaWWRpP0mXRW_qRPP?usp=sharing) and organize the directory into the following structure:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── checkpoints
+ │
+ └───
+ │
+ └─── epoch_20.pth
+
+```
+
+### Test with custom images
+To reconstruct 3d faces from test images, organize the test image folder as follows:
+```
+Deep3DFaceRecon_pytorch
+│
+└───
+ │
+ └─── *.jpg/*.png
+ |
+ └─── detections
+ |
+ └─── *.txt
+```
+The \*.jpg/\*.png files are test images. The \*.txt files are detected 5 facial landmarks with a shape of 5x2, and have the same name as the corresponding images. Check [./datasets/examples](datasets/examples) for a reference.
+
+Then, run the test script:
+```
+# get reconstruction results of your custom images
+python test.py --name= --epoch=20 --img_folder=
+
+# get reconstruction results of example images
+python test.py --name= --epoch=20 --img_folder=./datasets/examples
+```
+**_Following [#108](https://github.com/sicxu/Deep3DFaceRecon_pytorch/issues/108), if you don't have OpenGL environment, you can simply add "--use_opengl False" to use CUDA context. Make sure you have updated the nvdiffrast to the latest version._**
+
+Results will be saved into ./checkpoints//results/, which contain the following files:
+| \*.png | A combination of cropped input image, reconstructed image, and visualization of projected landmarks.
+|:----|:-----------|
+| \*.obj | Reconstructed 3d face mesh with predicted color (texture+illumination) in the world coordinate space. Best viewed in Meshlab. |
+| \*.mat | Predicted 257-dimensional coefficients and 68 projected 2d facial landmarks. Best viewed in Matlab.
+
+## Training a model from scratch
+### Prepare prerequisite models
+1. We rely on [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch) to extract identity features for loss computation. Download the pre-trained model from Arcface using this [link](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch#ms1mv3). By default, we use the resnet50 backbone ([ms1mv3_arcface_r50_fp16](https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215583&cid=4A83B6B633B029CC)), organize the download files into the following structure:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── checkpoints
+ │
+ └─── recog_model
+ │
+ └─── ms1mv3_arcface_r50_fp16
+ |
+ └─── backbone.pth
+```
+2. We initialize R-Net using the weights trained on [ImageNet](https://image-net.org/). Download the weights provided by PyTorch using this [link](https://download.pytorch.org/models/resnet50-0676ba61.pth), and organize the file as the following structure:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── checkpoints
+ │
+ └─── init_model
+ │
+ └─── resnet50-0676ba61.pth
+```
+3. We provide a landmark detector (tensorflow model) to extract 68 facial landmarks for loss computation. The detector is trained on [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), [LFW](http://vis-www.cs.umass.edu/lfw/), and [LS3D-W](https://www.adrianbulat.com/face-alignment) datasets. Download the trained model using this [link (google drive)](https://drive.google.com/file/d/1Jl1yy2v7lIJLTRVIpgg2wvxYITI8Dkmw/view?usp=sharing) and organize the file as follows:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── checkpoints
+ │
+ └─── lm_model
+ │
+ └─── 68lm_detector.pb
+```
+### Data preparation
+1. To train a model with custom images,5 facial landmarks of each image are needed in advance for an image pre-alignment process. We recommend using [dlib](http://dlib.net/) or [MTCNN](https://github.com/ipazc/mtcnn) to detect these landmarks. Then, organize all files into the following structure:
+```
+Deep3DFaceRecon_pytorch
+│
+└─── datasets
+ │
+ └───
+ │
+ └─── *.png/*.jpg
+ |
+ └─── detections
+ |
+ └─── *.txt
+```
+The \*.txt files contain 5 facial landmarks with a shape of 5x2, and should have the same name with their corresponding images.
+
+2. Generate 68 landmarks and skin attention mask for images using the following script:
+```
+# preprocess training images
+python data_preparation.py --img_folder
+
+# alternatively, you can preprocess multiple image folders simultaneously
+python data_preparation.py --img_folder
+
+# preprocess validation images
+python data_preparation.py --img_folder --mode=val
+```
+The script will generate files of landmarks and skin masks, and save them into ./datasets/. In addition, it also generates a file containing the path of all training data into ./datalist which will then be used in the training script.
+
+### Train the face reconstruction network
+Run the following script to train a face reconstruction model using the pre-processed data:
+```
+# train with single GPU
+python train.py --name= --gpu_ids=0
+
+# train with multiple GPUs
+python train.py --name= --gpu_ids=0,1
+
+# train with other custom settings
+python train.py --name= --gpu_ids=0 --batch_size=32 --n_epochs=20
+```
+Training logs and model parameters will be saved into ./checkpoints/.
+
+By default, the script uses a batchsize of 32 and will train the model with 20 epochs. For reference, the pre-trained model in this repo is trained with the default setting on a image collection of 300k images. A single iteration takes 0.8~0.9s on a single Tesla M40 GPU. The total training process takes around two days.
+
+To use a trained model, see [Inference](https://github.com/sicxu/Deep3DFaceRecon_pytorch#inference-with-a-pre-trained-model) section.
+## Contact
+If you have any questions, please contact the paper authors.
+
+## Citation
+
+Please cite the following paper if this model helps your research:
+
+ @inproceedings{deng2019accurate,
+ title={Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set},
+ author={Yu Deng and Jiaolong Yang and Sicheng Xu and Dong Chen and Yunde Jia and Xin Tong},
+ booktitle={IEEE Computer Vision and Pattern Recognition Workshops},
+ year={2019}
+ }
+##
+The face images on this page are from the public [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset released by MMLab, CUHK.
+
+Part of the code in this implementation takes [CUT](https://github.com/taesungp/contrastive-unpaired-translation) as a reference.
+
diff --git a/Deep3DFaceRecon_pytorch/data/__init__.py b/Deep3DFaceRecon_pytorch/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aaef54bd893314aa8a7b97af8625659cd8d3bfe
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data/__init__.py
@@ -0,0 +1,118 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point from data loader.
+ -- : (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import importlib
+
+import numpy as np
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+ """Import the module "data/[dataset_name]_dataset.py".
+
+ In the file, the class called DatasetNameDataset() will
+ be instantiated. It has to be a subclass of BaseDataset,
+ and it is case-insensitive.
+ """
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ dataset = None
+ target_dataset_name = dataset_name.replace("_", "") + "dataset"
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ raise NotImplementedError(
+ "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
+ % (dataset_filename, target_dataset_name)
+ )
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ """Return the static method of the dataset class."""
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+ """Create a dataset given the option.
+
+ This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from data import create_dataset
+ >>> dataset = create_dataset(opt)
+ """
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
+ dataset = data_loader.load_data()
+ return dataset
+
+
+class CustomDatasetDataLoader:
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self, opt, rank=0):
+ """Initialize this class
+
+ Step 1: create a dataset instance given the name [dataset_mode]
+ Step 2: create a multi-threaded data loader.
+ """
+ self.opt = opt
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
+ self.dataset = dataset_class(opt)
+ self.sampler = None
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+ if opt.use_ddp and opt.isTrain:
+ world_size = opt.world_size
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
+ self.dataset, num_replicas=world_size, rank=rank, shuffle=not opt.serial_batches
+ )
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ sampler=self.sampler,
+ num_workers=int(opt.num_threads / world_size),
+ batch_size=int(opt.batch_size / world_size),
+ drop_last=True,
+ )
+ else:
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=(not opt.serial_batches) and opt.isTrain,
+ num_workers=int(opt.num_threads),
+ drop_last=True,
+ )
+
+ def set_epoch(self, epoch):
+ self.dataset.current_epoch = epoch
+ if self.sampler is not None:
+ self.sampler.set_epoch(epoch)
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/Deep3DFaceRecon_pytorch/data/base_dataset.py b/Deep3DFaceRecon_pytorch/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7d86153c9b17f7365b6480bb95ec272ff7fc095
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data/base_dataset.py
@@ -0,0 +1,132 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+from abc import ABC
+from abc import abstractmethod
+
+import numpy as np
+import torch.utils.data as data
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+class BaseDataset(data.Dataset, ABC):
+ """This class is an abstract base class (ABC) for datasets.
+
+ To create a subclass, you need to implement the following four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point.
+ -- : (optionally) add dataset-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the class; save the options in the class
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ self.opt = opt
+ # self.root = opt.dataroot
+ self.current_epoch = 0
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return 0
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+ """
+ pass
+
+
+def get_transform(grayscale=False):
+ transform_list = []
+ if grayscale:
+ transform_list.append(transforms.Grayscale(1))
+ transform_list += [transforms.ToTensor()]
+ return transforms.Compose(transform_list)
+
+
+def get_affine_mat(opt, size):
+ shift_x, shift_y, scale, rot_angle, flip = 0.0, 0.0, 1.0, 0.0, False
+ w, h = size
+
+ if "shift" in opt.preprocess:
+ shift_pixs = int(opt.shift_pixs)
+ shift_x = random.randint(-shift_pixs, shift_pixs)
+ shift_y = random.randint(-shift_pixs, shift_pixs)
+ if "scale" in opt.preprocess:
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+ if "rot" in opt.preprocess:
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
+ rot_rad = -rot_angle * np.pi / 180
+ if "flip" in opt.preprocess:
+ flip = random.random() > 0.5
+
+ shift_to_origin = np.array([1, 0, -w // 2, 0, 1, -h // 2, 0, 0, 1]).reshape([3, 3])
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape(
+ [3, 3]
+ )
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+ shift_to_center = np.array([1, 0, w // 2, 0, 1, h // 2, 0, 0, 1]).reshape([3, 3])
+
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+ affine_inv = np.linalg.inv(affine)
+ return affine, affine_inv, flip
+
+
+def apply_img_affine(img, affine_inv, method=Image.Resampling.BICUBIC):
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.Resampling.BICUBIC)
+
+
+def apply_lm_affine(landmark, affine, flip, size):
+ _, h = size
+ lm = landmark.copy()
+ lm[:, 1] = h - 1 - lm[:, 1]
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+ lm = lm @ np.transpose(affine)
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
+ lm = lm[:, :2]
+ lm[:, 1] = h - 1 - lm[:, 1]
+ if flip:
+ lm_ = lm.copy()
+ lm_[:17] = lm[16::-1]
+ lm_[17:22] = lm[26:21:-1]
+ lm_[22:27] = lm[21:16:-1]
+ lm_[31:36] = lm[35:30:-1]
+ lm_[36:40] = lm[45:41:-1]
+ lm_[40:42] = lm[47:45:-1]
+ lm_[42:46] = lm[39:35:-1]
+ lm_[46:48] = lm[41:39:-1]
+ lm_[48:55] = lm[54:47:-1]
+ lm_[55:60] = lm[59:54:-1]
+ lm_[60:65] = lm[64:59:-1]
+ lm_[65:68] = lm[67:64:-1]
+ lm = lm_
+ return lm
diff --git a/Deep3DFaceRecon_pytorch/data/flist_dataset.py b/Deep3DFaceRecon_pytorch/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..21265f361b46cbbce126b8541284d1bebf241e08
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data/flist_dataset.py
@@ -0,0 +1,129 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+import json
+import os.path
+import pickle
+import random
+
+import numpy as np
+import torch
+import util.util as util
+from data.base_dataset import apply_img_affine
+from data.base_dataset import apply_lm_affine
+from data.base_dataset import BaseDataset
+from data.base_dataset import get_affine_mat
+from data.base_dataset import get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+from scipy.io import loadmat
+from scipy.io import savemat
+from util.load_mats import load_lm3d
+from util.preprocess import align_img
+from util.preprocess import estimate_norm
+
+
+def default_flist_reader(flist):
+ """
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
+ """
+ imlist = []
+ with open(flist, "r") as rf:
+ for line in rf.readlines():
+ impath = line.strip()
+ imlist.append(impath)
+
+ return imlist
+
+
+def jason_flist_reader(flist):
+ with open(flist, "r") as fp:
+ info = json.load(fp)
+ return info
+
+
+def parse_label(label):
+ return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+ """
+ It requires one directories to host training images '/path/to/data/train'
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ msk_names = default_flist_reader(opt.flist)
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+ self.size = len(self.msk_paths)
+ self.opt = opt
+
+ self.name = "train" if opt.isTrain else "val"
+ if "_" in opt.flist:
+ self.name += "_" + opt.flist.split(os.sep)[-1].split("_")[0]
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ img (tensor) -- an image in the input domain
+ msk (tensor) -- its corresponding attention mask
+ lm (tensor) -- its corresponding 3d landmarks
+ im_paths (str) -- image paths
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
+ """
+ msk_path = self.msk_paths[index % self.size] # make sure index is within then range
+ img_path = msk_path.replace("mask/", "")
+ lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"
+
+ raw_img = Image.open(img_path).convert("RGB")
+ raw_msk = Image.open(msk_path).convert("RGB")
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+ aug_flag = self.opt.use_aug and self.opt.isTrain
+ if aug_flag:
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+ _, H = img.size
+ M = estimate_norm(lm, H)
+ transform = get_transform()
+ img_tensor = transform(img)
+ msk_tensor = transform(msk)[:1, ...]
+ lm_tensor = parse_label(lm)
+ M_tensor = parse_label(M)
+
+ return {
+ "imgs": img_tensor,
+ "lms": lm_tensor,
+ "msks": msk_tensor,
+ "M": M_tensor,
+ "im_paths": img_path,
+ "aug_flag": aug_flag,
+ "dataset": self.name,
+ }
+
+ def _augmentation(self, img, lm, opt, msk=None):
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
+ img = apply_img_affine(img, affine_inv)
+ lm = apply_lm_affine(lm, affine, flip, img.size)
+ if msk is not None:
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+ return img, lm, msk
+
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return self.size
diff --git a/Deep3DFaceRecon_pytorch/data/image_folder.py b/Deep3DFaceRecon_pytorch/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a02dfe06948ddb90a76667d538f7e3ea47a33e8
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data/image_folder.py
@@ -0,0 +1,77 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+import os.path
+
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+
+IMG_EXTENSIONS = [
+ ".jpg",
+ ".JPG",
+ ".jpeg",
+ ".JPEG",
+ ".png",
+ ".PNG",
+ ".ppm",
+ ".PPM",
+ ".bmp",
+ ".BMP",
+ ".tif",
+ ".TIF",
+ ".tiff",
+ ".TIFF",
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+ images = []
+ assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir
+
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images[: min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+ return Image.open(path).convert("RGB")
+
+
+class ImageFolder(data.Dataset):
+ def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise (
+ RuntimeError(
+ "Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
+ )
+ )
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/Deep3DFaceRecon_pytorch/data/template_dataset.py b/Deep3DFaceRecon_pytorch/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d776bde9c36396865e149922305cb2284942529b
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data/template_dataset.py
@@ -0,0 +1,80 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+You need to implement the following functions:
+ -- : Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset
+from data.base_dataset import get_transform
+
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument("--new_dataset_option", type=float, default=1.0, help="new dataset option")
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (have been done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = (
+ []
+ ) # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+ # define the default transform function. You can use ; You can also define your custom transform function
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+ Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = "temp" # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {"data_A": data_A, "data_B": data_B, "path": path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/Deep3DFaceRecon_pytorch/data_preparation.py b/Deep3DFaceRecon_pytorch/data_preparation.py
new file mode 100644
index 0000000000000000000000000000000000000000..74b9361d45a56342b10530e015b4099b91bbf625
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/data_preparation.py
@@ -0,0 +1,57 @@
+"""This script is the data preparation script for Deep3DFaceRecon_pytorch
+"""
+import argparse
+import os
+import warnings
+
+import numpy as np
+from util.detect_lm68 import detect_68p
+from util.detect_lm68 import load_lm_graph
+from util.generate_list import check_list
+from util.generate_list import write_list
+from util.skin_mask import get_skin_mask
+
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--data_root", type=str, default="datasets", help="root directory for training data")
+parser.add_argument("--img_folder", nargs="+", required=True, help="folders of training images")
+parser.add_argument("--mode", type=str, default="train", help="train or val")
+opt = parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+
+def data_prepare(folder_list, mode):
+
+ lm_sess, input_op, output_op = load_lm_graph(
+ "./checkpoints/lm_model/68lm_detector.pb"
+ ) # load a tensorflow version 68-landmark detector
+
+ for img_folder in folder_list:
+ detect_68p(img_folder, lm_sess, input_op, output_op) # detect landmarks for images
+ get_skin_mask(img_folder) # generate skin attention mask for images
+
+ # create files that record path to all training data
+ msks_list = []
+ for img_folder in folder_list:
+ path = os.path.join(img_folder, "mask")
+ msks_list += [
+ "/".join([img_folder, "mask", i])
+ for i in sorted(os.listdir(path))
+ if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i
+ ]
+
+ imgs_list = [i.replace("mask/", "") for i in msks_list]
+ lms_list = [i.replace("mask", "landmarks") for i in msks_list]
+ lms_list = [".".join(i.split(".")[:-1]) + ".txt" for i in lms_list]
+
+ lms_list_final, imgs_list_final, msks_list_final = check_list(
+ lms_list, imgs_list, msks_list
+ ) # check if the path is valid
+ write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
+
+
+if __name__ == "__main__":
+ print("Datasets:", opt.img_folder)
+ data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode)
diff --git a/Deep3DFaceRecon_pytorch/environment.yml b/Deep3DFaceRecon_pytorch/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c16d73cfd9d8e5488362aee9067ca5ef9897f60d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/environment.yml
@@ -0,0 +1,24 @@
+name: deep3d_pytorch
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.6
+ - pytorch=1.6.0
+ - torchvision=0.7.0
+ - numpy=1.18.1
+ - scikit-image=0.16.2
+ - scipy=1.4.1
+ - pillow=6.2.1
+ - pip=20.0.2
+ - ipython=7.13.0
+ - yaml=0.1.7
+ - pip:
+ - matplotlib==2.2.5
+ - opencv-python==3.4.9.33
+ - tensorboard==1.15.0
+ - tensorflow==1.15.0
+ - kornia==0.5.5
+ - dominate==2.6.0
+ - trimesh==3.9.20
\ No newline at end of file
diff --git a/Deep3DFaceRecon_pytorch/models/__init__.py b/Deep3DFaceRecon_pytorch/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4642bcbf549aa1354268ed7b6567d20eff6b1237
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/__init__.py
@@ -0,0 +1,69 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate loss, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+import importlib
+
+from Deep3DFaceRecon_pytorch.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace("_", "") + "model"
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print(
+ "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
+ % (model_filename, target_model_name)
+ )
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function warps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/Deep3DFaceRecon_pytorch/models/__pycache__/__init__.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0f5dd8cda00a950e69b90f7d3c8d0abb21858aa
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/__pycache__/base_model.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee4ae4c8a8e3e14c47e405a901abd4769c61c076
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/__pycache__/bfm.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/__pycache__/bfm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4d7594423e1484031c01ed95e86c8a5f4c463ef
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/__pycache__/bfm.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/__pycache__/networks.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/__pycache__/networks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bdbdad04a18832a3ad1c92ca85a3a69698465c5
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/__pycache__/networks.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/README.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d391f63684dd1f47900dc6449a5e22fa25e3da3
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/README.md
@@ -0,0 +1,218 @@
+# Distributed Arcface Training in Pytorch
+
+The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
+
+[](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
+
+## Requirements
+
+To avail the latest features of PyTorch, we have upgraded to version 1.12.0.
+
+- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
+- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
+- `pip install -r requirement.txt`.
+
+## How to Training
+
+To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
+
+### 1. To run on one GPU:
+
+```shell
+python train_v2.py configs/ms1mv3_r50_onegpu
+```
+
+Note:
+It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
+
+
+### 2. To run on a machine with 8 GPUs:
+
+```shell
+torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
+```
+
+### 3. To run on 2 machines with 8 GPUs each:
+
+Node 0:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+Node 1:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+### 4. Run ViT-B on a machine with 24k batchsize:
+
+```shell
+torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
+```
+
+
+## Download Datasets or Prepare Datasets
+- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
+- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
+- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
+- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
+- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
+
+Note:
+If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it.
+Example:
+
+`python scripts/shuffle_rec.py ms1m-retinaface-t1`
+
+You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
+
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found in here.
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
+
+ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
+recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
+As the result, we can evaluate the FAIR performance for different algorithms.
+
+For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
+
+
+#### 1. Training on Single-Host GPU
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
+| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
+| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
+| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
+| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
+| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
+| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
+| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
+| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
+| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
+| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
+| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
+| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
+| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
+
+#### 2. Training on Multi-Host GPU
+
+| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
+| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
+| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
+| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
+| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
+
+`r100(128*32)` means backbone is r100, batchsize per gpu is 128, the number of gpus is 32.
+
+
+
+#### 3. ViT For Face Recognition
+
+| Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
+| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
+| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
+| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
+| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
+| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
+
+`WF42M` means WebFace42M, `PFC-0.3` means negivate class centers sample rate is 0.3.
+
+#### 4. Noisy Datasets
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:-------------------------|:---------|:------------|:------------|:------------|:---------|
+| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
+| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
+| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
+| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
+
+`WF12M` means WebFace12M, `+PFC-0.1*` denotes additional abnormal inter-class filtering.
+
+
+
+## Speed Benchmark
+
+
+
+**Arcface-Torch** is an efficient tool for training large-scale face recognition training sets. When the number of classes in the training sets exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training performance and lower GPU memory utilization. The partial FC is a sparse variant of the model parallel architecture for large-scale face recognition, utilizing a sparse softmax that dynamically samples a subset of class centers for each training batch. During each iteration, only a sparse portion of the parameters are updated, leading to a significant reduction in GPU memory requirements and computational demands. With the partial FC approach, it is possible to train sets with up to 29 million identities, the largest to date. Furthermore, the partial FC method supports multi-machine distributed training and mixed precision training.
+
+
+
+More details see
+[speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Optimal)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 4681 | 4824 | 5004 |
+| 1400000 | **1672** | 3043 | 4738 |
+| 5500000 | **-** | **1389** | 3975 |
+| 8000000 | **-** | **-** | 3565 |
+| 16000000 | **-** | **-** | 2679 |
+| 29000000 | **-** | **-** | **1855** |
+
+> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Optimal)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 7358 | 5306 | 4868 |
+| 1400000 | 32252 | 11178 | 6056 |
+| 5500000 | **-** | 32188 | 9854 |
+| 8000000 | **-** | **-** | 12310 |
+| 16000000 | **-** | **-** | 19950 |
+| 29000000 | **-** | **-** | 32324 |
+
+
+## Citations
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{An_2022_CVPR,
+ author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
+ title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2022},
+ pages={4042-4051}
+}
+@inproceedings{zhu2021webface260m,
+ title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
+ author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={10492--10502},
+ year={2021}
+}
+```
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94288c3af835e3513ddc70eb4cfb7f7e86852e3f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,157 @@
+from .iresnet import iresnet100
+from .iresnet import iresnet18
+from .iresnet import iresnet200
+from .iresnet import iresnet34
+from .iresnet import iresnet50
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+
+ return iresnet2060(False, **kwargs)
+
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+
+ elif name == "mbf_large":
+ from .mobilefacenet import get_mbf_large
+
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf_large(fp16=fp16, num_features=num_features)
+
+ elif name == "vit_t":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=256,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ )
+
+ elif name == "vit_t_dp005_mask0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=256,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.0,
+ )
+
+ elif name == "vit_s":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ )
+
+ elif name == "vit_s_dp005_mask_0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.0,
+ )
+
+ elif name == "vit_b":
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ using_checkpoint=True,
+ )
+
+ elif name == "vit_b_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.05,
+ using_checkpoint=True,
+ )
+
+ elif name == "vit_l_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=768,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.05,
+ using_checkpoint=True,
+ )
+
+ else:
+ raise ValueError()
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b99df3bd964054c194464e21115ed2fb7d8bd6e
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebb76456ff47c5c3112e619464cdd90b67620e57
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0538adc4517e84e544fd0057d412678277543d93
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3eea3ac6c1c92a9a92dab3518630cb5039bdf8
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,198 @@
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
+using_ckpt = False
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(
+ inplanes,
+ eps=1e-05,
+ )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward_impl(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+ def forward(self, x):
+ if self.training and using_ckpt:
+ return checkpoint(self.forward_impl, x)
+ else:
+ return self.forward_impl(x)
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(
+ self,
+ block,
+ layers,
+ dropout=0,
+ num_features=512,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ fp16=False,
+ ):
+ super(IResNet, self).__init__()
+ self.extra_gflops = 0.0
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+ )
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(
+ 512 * block.expansion,
+ eps=1e-05,
+ )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(
+ planes * block.expansion,
+ eps=1e-05,
+ ),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
+ )
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet100", IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet200", IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet2060.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..468b00201a06e33653f1e0aa738668cf4ee68fb0
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,182 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ["iresnet2060"]
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(
+ inplanes,
+ eps=1e-05,
+ )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(
+ self,
+ block,
+ layers,
+ dropout=0,
+ num_features=512,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ fp16=False,
+ ):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+ )
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(
+ 512 * block.expansion,
+ eps=1e-05,
+ )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(
+ planes * block.expansion,
+ eps=1e-05,
+ ),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
+ )
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet2060", IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/mobilefacenet.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36953e8172aa7cdbd58decbf1414c061459526d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,160 @@
+"""
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+"""
+import torch
+import torch.nn as nn
+from torch.nn import BatchNorm1d
+from torch.nn import BatchNorm2d
+from torch.nn import Conv2d
+from torch.nn import Linear
+from torch.nn import Module
+from torch.nn import PReLU
+from torch.nn import Sequential
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
+ super(MobileFaceNet, self).__init__()
+ self.scale = scale
+ self.fp16 = fp16
+ self.layers = nn.ModuleList()
+ self.layers.append(ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)))
+ if blocks[0] == 1:
+ self.layers.append(
+ ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
+ )
+ else:
+ self.layers.append(
+ Residual(
+ 64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ )
+
+ self.layers.extend(
+ [
+ DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(
+ 64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(
+ 128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(
+ 128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ ]
+ )
+
+ self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ for func in self.layers:
+ x = func(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
+
+
+def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/vit.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bb28b60309c0e904adb0830ffcf265d9f1dec6
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/vit.py
@@ -0,0 +1,302 @@
+from typing import Callable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+from timm.models.layers import to_2tuple
+from timm.models.layers import trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class VITBatchNorm(nn.Module):
+ def __init__(self, num_features):
+ super().__init__()
+ self.num_features = num_features
+ self.bn = nn.BatchNorm1d(num_features=num_features)
+
+ def forward(self, x):
+ return self.bn(x)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+
+ with torch.cuda.amp.autocast(True):
+ batch_size, num_token, embed_dim = x.shape
+ # qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
+ qkv = (
+ self.qkv(x)
+ .reshape(batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ with torch.cuda.amp.autocast(False):
+ q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
+ with torch.cuda.amp.autocast(True):
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ num_patches: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ drop_path: float = 0.0,
+ act_layer: Callable = nn.ReLU6,
+ norm_layer: str = "ln",
+ patch_n: int = 144,
+ ):
+ super().__init__()
+
+ if norm_layer == "bn":
+ self.norm1 = VITBatchNorm(num_features=num_patches)
+ self.norm2 = VITBatchNorm(num_features=num_patches)
+ elif norm_layer == "ln":
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000**3)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ with torch.cuda.amp.autocast(True):
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ batch_size, channels, height, width = x.shape
+ assert (
+ height == self.img_size[0] and width == self.img_size[1]
+ ), f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size: int = 112,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ hybrid_backbone: Optional[None] = None,
+ norm_layer: str = "ln",
+ mask_ratio=0.1,
+ using_checkpoint=False,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ # num_features for consistency with other models
+ self.num_features = self.embed_dim = embed_dim
+
+ if hybrid_backbone is not None:
+ raise ValueError
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
+ )
+ self.mask_ratio = mask_ratio
+ self.using_checkpoint = using_checkpoint
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ patch_n = (img_size // patch_size) ** 2
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ num_patches=num_patches,
+ patch_n=patch_n,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.extra_gflops = 0.0
+ for _block in self.blocks:
+ self.extra_gflops += _block.extra_gflops
+
+ if norm_layer == "ln":
+ self.norm = nn.LayerNorm(embed_dim)
+ elif norm_layer == "bn":
+ self.norm = VITBatchNorm(self.num_patches)
+
+ # features head
+ self.feature = nn.Sequential(
+ nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
+ nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
+ nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
+ nn.BatchNorm1d(num_features=num_classes, eps=2e-5),
+ )
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ torch.nn.init.normal_(self.mask_token, std=0.02)
+ trunc_normal_(self.pos_embed, std=0.02)
+ # trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def get_classifier(self):
+ return self.head
+
+ def random_masking(self, x, mask_ratio=0.1):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.size() # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ # ascend: small is keep, large is remove
+ ids_shuffle = torch.argsort(noise, dim=1)
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ if self.training and self.mask_ratio > 0:
+ x, _, ids_restore = self.random_masking(x)
+
+ for func in self.blocks:
+ if self.using_checkpoint and self.training:
+ from torch.utils.checkpoint import checkpoint
+
+ x = checkpoint(func, x)
+ else:
+ x = func(x)
+ x = self.norm(x.float())
+
+ if self.training and self.mask_ratio > 0:
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
+ x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
+ x = x_
+ return torch.reshape(x, (B, self.num_patches * self.embed_dim))
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.feature(x)
+ return x
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b110223a00e6c1975709eb58968766792f18dd1
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512 # total_batch_size = batch_size * num_gpus
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 30 * 10000
+config.num_image = 100000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7c30bec70a7173114e8b29e492cbc483ab55a6c
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+
+# Margin Base Softmax
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.save_all_states = False
+config.output = "ms1mv3_arcface_r50"
+
+config.embedding_size = 512
+
+# Partial FC
+config.sample_rate = 1
+config.interclass_filtering_threshold = 0
+
+config.fp16 = False
+config.batch_size = 128
+
+# For SGD
+config.optimizer = "sgd"
+config.lr = 0.1
+config.momentum = 0.9
+config.weight_decay = 5e-4
+
+# For AdamW
+# config.optimizer = "adamw"
+# config.lr = 0.001
+# config.weight_decay = 0.1
+
+config.verbose = 2000
+config.frequent = 10
+
+# For Large Sacle Dataset, such as WebFace42M
+config.dali = False
+
+# Gradient ACC
+config.gradient_acc = 1
+
+# setup seed
+config.seed = 2048
+
+# dataload numworkers
+config.num_workers = 2
+
+# WandB Logger
+config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+config.suffix_run_name = None
+config.using_wandb = False
+config.wandb_entity = "entity"
+config.wandb_project = "project"
+config.wandb_log_all = True
+config.save_artifacts = False
+config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..03447e982487f19c40c814448f9fdfea6c306b0f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6676e1f92d2f2d2c7f5ef5b0d03f18311d0b48
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bd79b92986294ff5cb1f53afc41f8b07e5dc08
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..098afd8d2d6ca353d0b02281d02ac54e584f8281
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..24fd0417f2219e63e91fdbc92c609ebc596cee21
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..236721a526489b2cac7ba66a22bfc3d650e744cd
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb093f42440a0cb0c3bfdf7172f7e2fa478619c7
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..98263fc00dd15f0ad99c2a24a398433fa1c563f8
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1a4b5d7eebf5df9a7340e07a003450fd1df976
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e2e92ac44d92d76682dd083afd57920516e229
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.02
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de94fcb32cad796bda63521e4f81a4f7fe88923b
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
new file mode 100644
index 0000000000000000000000000000000000000000..a766f4154bb801b57d0f9519748b63941e349330
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1018b7f0d0320678b33b212eed5751badf72ee
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde56fed6d8513b95882b7701f93f8574afbca9c
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_mbf.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cb93b2f168e3a64e65d1f8d6cf058e41676c6a
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_mbf.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_pfc02_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f0f0ec0ce5c523bace8b7869181ea807e72423
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_pfc02_r100.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..2663dc950c42f699428d92e7349a1cf5ed8d848d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r100.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7284663d6afbe6f205c8c9f10cd454ef1045ca
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..2885816cb9b635c526d1d2269c606e93fa54a2e6
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a6bb79da7eaa3f111e9efedf507e46a953c9aa
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..035684732003b5c7b8fe8ea34e097bd22fbcca37
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 256
+config.lr = 0.3
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 1
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02bdf3afe8370086cf64fd112244b00cee35a6f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.6
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 4
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8407943ffef4ae3ee02ddb3f2361a9ac655cbb
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f627fa94046d22ab0f0f12a8e339dc2cedfd81
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe402f9f1a3ae044b9ed7150c5743141ed3f1b1
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..9916872b3af4330448f70f3cf72d45be5a200f6d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.2
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..22dcbf11f7e5ea3943068bf146be400210505570
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf21c97a8c7c0568d0783432b4526ba78138926
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d35830ba107f27eea9b849abe88b0b4b09bdd0c
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34dd1c11f489d9c5c1b23c3677d303aafe46da6
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r200"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44a5d771e17ecbeffe3437f3500e9d0c9dcc105
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe7fe6b1ecde9034cf6b647c0558f96bb1d41c3
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b153aa6a36a9a883153245c49617c2d9e11939
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_l_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ce7010d9c297ed0832dcb5639d552078cea95c
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_s_dp005_mask_0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..8516755b656b21536da177402ef6066e3e1039dd
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f6559ad3d66659dba3bc9c29e35c76a62b3576
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 256
+config.gradient_acc = 12 # total batchsize is 256 * 12
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf8c563dab6ce4f45b694efa4837a4d52a98af3
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 512
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_mbf.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ee67b62acb4432b9d4916400ec79433f7dd10ea
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r100.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..914d71987fdf2cbffe51a3e17938bc1047e1d319
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r50.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44fc68da88dd2c2d1e003c345ef04a5f43ead86
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/dataset.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..595eda79c56400a3243b2bd0d13a0dce9b8afd1d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/dataset.py
@@ -0,0 +1,268 @@
+import numbers
+import os
+import queue as Queue
+import threading
+from functools import partial
+from typing import Iterable
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import ImageFolder
+from utils.utils_distributed_sampler import DistributedSampler
+from utils.utils_distributed_sampler import get_dist_info
+from utils.utils_distributed_sampler import worker_init_fn
+
+
+def get_dataloader(
+ root_dir,
+ local_rank,
+ batch_size,
+ dali=False,
+ seed=2048,
+ num_workers=2,
+) -> Iterable:
+
+ rec = os.path.join(root_dir, "train.rec")
+ idx = os.path.join(root_dir, "train.idx")
+ train_set = None
+
+ # Synthetic
+ if root_dir == "synthetic":
+ train_set = SyntheticDataset()
+ dali = False
+
+ # Mxnet RecordIO
+ elif os.path.exists(rec) and os.path.exists(idx):
+ train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
+
+ # Image Folder
+ else:
+ transform = transforms.Compose(
+ [
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+ train_set = ImageFolder(root_dir, transform)
+
+ # DALI
+ if dali:
+ return dali_data_iter(batch_size=batch_size, rec_file=rec, idx_file=idx, num_threads=2, local_rank=local_rank)
+
+ rank, world_size = get_dist_info()
+ train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)
+
+ if seed is None:
+ init_fn = None
+ else:
+ init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
+
+ train_loader = DataLoaderX(
+ local_rank=local_rank,
+ dataset=train_set,
+ batch_size=batch_size,
+ sampler=train_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ worker_init_fn=init_fn,
+ )
+
+ return train_loader
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class DataLoaderX(DataLoader):
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [
+ transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, "train.rec")
+ path_imgidx = os.path.join(root_dir, "train.idx")
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+ def __init__(self):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
+
+
+def dali_data_iter(
+ batch_size: int,
+ rec_file: str,
+ idx_file: str,
+ num_threads: int,
+ initial_fill=32768,
+ random_shuffle=True,
+ prefetch_queue_depth=1,
+ local_rank=0,
+ name="reader",
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5),
+):
+ """
+ Parameters:
+ ----------
+ initial_fill: int
+ Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
+
+ """
+ rank: int = distributed.get_rank()
+ world_size: int = distributed.get_world_size()
+ import nvidia.dali.fn as fn
+ import nvidia.dali.types as types
+ from nvidia.dali.pipeline import Pipeline
+ from nvidia.dali.plugin.pytorch import DALIClassificationIterator
+
+ pipe = Pipeline(
+ batch_size=batch_size,
+ num_threads=num_threads,
+ device_id=local_rank,
+ prefetch_queue_depth=prefetch_queue_depth,
+ )
+ condition_flip = fn.random.coin_flip(probability=0.5)
+ with pipe:
+ jpegs, labels = fn.readers.mxnet(
+ path=rec_file,
+ index_path=idx_file,
+ initial_fill=initial_fill,
+ num_shards=world_size,
+ shard_id=rank,
+ random_shuffle=random_shuffle,
+ pad_last_batch=False,
+ name=name,
+ )
+ images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
+ images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
+ pipe.set_outputs(images, labels)
+ pipe.build()
+ return DALIWarper(
+ DALIClassificationIterator(
+ pipelines=[pipe],
+ reader_name=name,
+ )
+ )
+
+
+@torch.no_grad()
+class DALIWarper(object):
+ def __init__(self, dali_iter):
+ self.iter = dali_iter
+
+ def __next__(self):
+ data_dict = self.iter.__next__()[0]
+ tensor_data = data_dict["data"].cuda()
+ tensor_label: torch.Tensor = data_dict["label"].cuda().long()
+ tensor_label.squeeze_()
+ return tensor_data, tensor_label
+
+ def __iter__(self):
+ return self
+
+ def reset(self):
+ self.iter.reset()
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/dist.sh b/Deep3DFaceRecon_pytorch/models/arcface_torch/dist.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9f3c6a5276a030652c9f2e81d535e0beb854f123
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/dist.sh
@@ -0,0 +1,15 @@
+ip_list=("ip1" "ip2" "ip3" "ip4")
+
+config=wf42m_pfc03_32gpu_r100
+
+for((node_rank=0;node_rank<${#ip_list[*]};node_rank++));
+do
+ ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun \
+ --nproc_per_node=8 \
+ --nnodes=${#ip_list[*]} \
+ --node_rank=$node_rank \
+ --master_addr=${ip_list[0]} \
+ --master_port=22345 train.py configs/$config" &
+done
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/eval.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ce1621357c03ee8a25c004e5f01850990df1628
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/eval.md
@@ -0,0 +1,43 @@
+## Eval on ICCV2021-MFR
+
+coming soon.
+
+
+## Eval IJBC
+You can eval ijbc with pytorch or onnx.
+
+
+1. Eval IJBC With Onnx
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC With Pytorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
+
+
+## Result
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) |
+|:---------------|:--------------------|:------------|:------------|:------------|
+| WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 |
+| WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 |
\ No newline at end of file
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..8824e7e3108adc76cee514a3e66a50f933c9c91f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install.md
@@ -0,0 +1,27 @@
+# Installation
+
+### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110)
+#### Linux and Windows
+- CUDA 11.3
+```shell
+
+pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102
+```
+
+### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190)
+#### Linux and Windows
+
+- CUDA 11.1
+```shell
+pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install_dali.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install_dali.md
new file mode 100644
index 0000000000000000000000000000000000000000..48743644d0dac8885efaecfbb7821d5639a4f732
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install_dali.md
@@ -0,0 +1,103 @@
+# Installation
+## Prerequisites
+
+1. Linux x64.
+2. NVIDIA Driver supporting CUDA 10.0 or later (i.e., 410.48 or later driver releases).
+3. (Optional) One or more of the following deep learning frameworks:
+
+ * [MXNet 1.3](http://mxnet.incubator.apache.org/) `mxnet-cu100` or later.
+ * [PyTorch 0.4](https://pytorch.org/) or later.
+ * [TensorFlow 1.7](https://www.tensorflow.org/) or later.
+
+## DALI in NGC Containers
+DALI is preinstalled in the TensorFlow, PyTorch, and MXNet containers in versions 18.07 and later on NVIDIA GPU Cloud.
+
+## pip - Official Releases
+
+### nvidia-dali
+
+Execute the following command to install the latest DALI for specified CUDA version (please check support matrix to see if your platform is supported):
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110
+ ```
+
+
+> Note: CUDA 11.0 build uses CUDA toolkit enhanced compatibility. It is built with the latest CUDA 11.x toolkit while it can run on the latest, stable CUDA 11.0 capable drivers (450.80 or later). Using the latest driver may enable additional functionality. More details can be found in [enhanced CUDA compatibility guide](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#enhanced-compat-minor-releases).
+
+> Note: Please always use the latest version of pip available (at least >= 19.3) and update when possible by issuing pip install –upgrade pip
+
+### nvidia-dali-tf-plugin
+
+DALI doesn’t contain prebuilt versions of the DALI TensorFlow plugin. It needs to be installed as a separate package which will be built against the currently installed version of TensorFlow:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda110
+ ```
+
+Installing this package will install `nvidia-dali-cudaXXX` and its dependencies, if they are not already installed. The package `tensorflow-gpu` must be installed before attempting to install `nvidia-dali-tf-plugin-cudaXXX`.
+
+> Note: The packages `nvidia-dali-tf-plugin-cudaXXX` and `nvidia-dali-cudaXXX` should be in exactly the same version. Therefore, installing the latest `nvidia-dali-tf-plugin-cudaXXX`, will replace any older `nvidia-dali-cudaXXX` version already installed. To work with older versions of DALI, provide the version explicitly to the `pip install` command.
+
+### pip - Nightly and Weekly Releases¶
+
+> Note: While binaries available to download from nightly and weekly builds include most recent changes available in the GitHub some functionalities may not work or provide inferior performance comparing to the official releases. Those builds are meant for the early adopters seeking for the most recent version available and being ready to boldly go where no man has gone before.
+
+> Note: It is recommended to uninstall regular DALI and TensorFlow plugin before installing nightly or weekly builds as they are installed in the same path
+
+#### Nightly Builds
+To access most recent nightly builds please use flowing release channel:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda102
+ ```
+
+ ```
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda110
+ ```
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda110
+ ```
+
+
+#### Weekly Builds
+
+Also, there is a weekly release channel with more thorough testing. To access most recent weekly builds please use the following release channel (available only for CUDA 11):
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-weekly-cuda110
+```
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-tf-plugin-week
+```
+
+
+---
+
+### For more information about Dali and installation, please refer to [DALI documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_custom_dataset.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_custom_dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..6fc18dbd33cfa68be61e73906b0c96a320a8e12c
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_custom_dataset.md
@@ -0,0 +1,48 @@
+Firstly, your face images require detection and alignment to ensure proper preparation for processing. Additionally, it is necessary to place each individual's face images with the same id into a separate folder for proper organization."
+
+
+```shell
+# directories and files for yours datsaets
+/image_folder
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmepeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using follow command
+python -m mxnet.tools.im2rec --list --recursive train image_folder
+
+# 2) create train.rec and train.idx using train.lst using following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_webface42m.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_webface42m.md
new file mode 100644
index 0000000000000000000000000000000000000000..e799ba74e04f911593a704e64810c1e9936307ff
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_webface42m.md
@@ -0,0 +1,58 @@
+
+
+
+## 1. Download Datasets and Unzip
+
+The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html.
+Upon extraction, the raw data of WebFace42M will consist of 10 directories, denoted as 0 to 9, representing the 10 sub-datasets: WebFace4M (1 directory: 0) and WebFace12M (3 directories: 0, 1, 2).
+
+## 2. Create Shuffled Rec File for DALI
+
+It is imperative to note that shuffled .rec files are crucial for DALI and the absence of shuffling in .rec files can result in decreased performance. Original .rec files generated in the InsightFace style are not compatible with Nvidia DALI and it is necessary to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file.
+
+
+```shell
+# directories and files for yours datsaets
+/WebFace42M_Root
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmepeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using follow command
+python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root
+
+# 2) create train.rec and train.idx using train.lst using following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/speed_benchmark.md b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+You need to use the following two commands to test the Partial FC training performance.
+The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50,
+batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) trainging.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) trainging.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel,
+and the training speed is 2.5 times faster than the model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/verification.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd8c7d08c5e671a55e4d03c0d9714d60e7f059d1
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/verification.py
@@ -0,0 +1,378 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca=0):
+ assert embeddings1.shape[0] == embeddings2.shape[0]
+ assert embeddings1.shape[1] == embeddings2.shape[1]
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print("doing pca on", fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set], actual_issame[test_set]
+ )
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set], actual_issame[test_set]
+ )
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
+ assert embeddings1.shape[0] == embeddings2.shape[0]
+ assert embeddings1.shape[1] == embeddings2.shape[1]
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind="slinear")
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(
+ thresholds, embeddings1, embeddings2, np.asarray(actual_issame), nrof_folds=nrof_folds, pca=pca
+ )
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(
+ thresholds, embeddings1, embeddings2, np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds
+ )
+ return tpr, fpr, accuracy, val, val_std, far
+
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, "rb") as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, "rb") as f:
+ bins, issame_list = pickle.load(f, encoding="bytes") # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print("loading bin", idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print("testing verification..")
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size : bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count) :, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ embeddings = embeddings_list[0].copy()
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print("infer time", time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set, backbone, batch_size, name="", data_extra=None, label_shape=None):
+ print("dump verification embedding..")
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra), label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count) :, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join("temp.bin")
+ with open(outname, "wb") as f:
+ pickle.dump((embeddings, issame_list), f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/eval_ijbc.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c3506a8db432049e16b9235d85efe58109b5a8
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,450 @@
+# coding: utf-8
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description="do ijb test")
+# general
+parser.add_argument("--model-prefix", default="", help="path to load model.")
+parser.add_argument("--image-path", default="", type=str, help="")
+parser.add_argument("--result-dir", default=".", type=str, help="")
+parser.add_argument("--batch-size", default=128, type=int, help="")
+parser.add_argument("--network", default="iresnet50", type=str, help="")
+parser.add_argument("--job", default="insightface", type=str, help="job name")
+parser.add_argument("--target", default="IJBC", type=str, help="target, set to IJBC or IJBB")
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True # if Ture, TestMode(N1)
+use_detector_score = True # if Ture, TestMode(D1)
+use_flip_test = True # if Ture, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array(
+ [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]],
+ dtype=np.float32,
+ )
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg, M, (self.image_size[1], self.image_size[0]), borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=" ", header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print("files:", len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[: len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(" ")
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print("batch", batch)
+ img_feats[batch * batch_size : batch * batch_size + batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size :]):
+ name_lmk_score = each_line.strip().split(" ")
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print("batch", batch)
+ img_feats[len(files) - rare_size :][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print("Finish Calculating {} template features.".format(count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == "IJBC" or target == "IJBB"
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join("%s/meta" % image_path, "%s_face_tid_mid.txt" % target.lower())
+)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % target.lower())
+)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = "%s/loose_crop" % image_path
+img_list_path = "%s/meta/%s_name_5pts_score.txt" % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list, model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+print("Feature Shape: ({} , {}) .".format(img_feats.shape[0], img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2 :]
+else:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2]
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats**2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2")))
+x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(
+ fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100))
+ )
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10**-6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle="--", linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale("log")
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.title("ROC on IJB")
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, "%s.pdf" % target.lower()))
+print(tpr_fpr_table)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/flops.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..62aa8ec433846693a0e71e6ab808048ca37e61fd
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/flops.py
@@ -0,0 +1,20 @@
+import argparse
+
+from backbones import get_model
+from ptflops import get_model_complexity_info
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("n", type=str, default="r100")
+ args = parser.parse_args()
+ net = get_model(args.n)
+ macs, params = get_model_complexity_info(
+ net, (3, 112, 112), as_strings=False, print_per_layer_stat=True, verbose=True
+ )
+ gmacs = macs / (1000**3)
+ print("%.3f GFLOPs" % gmacs)
+ print("%.3f Mparams" % (params / (1000**2)))
+
+ if hasattr(net, "extra_gflops"):
+ print("%.3f Extra-GFLOPs" % net.extra_gflops)
+ print("%.3f Total-GFLOPs" % (gmacs + net.extra_gflops))
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/inference.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aab06628b4f33a67284ea1446ddc7c38642c33f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/inference.py
@@ -0,0 +1,34 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="PyTorch ArcFace Training")
+ parser.add_argument("--network", type=str, default="r50", help="backbone network")
+ parser.add_argument("--weight", type=str, default="")
+ parser.add_argument("--img", type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/losses.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7805d8f088e9b91f48b29d8304f87927ca65e0c4
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/losses.py
@@ -0,0 +1,95 @@
+import math
+
+import torch
+
+
+class CombinedMarginLoss(torch.nn.Module):
+ def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
+ super().__init__()
+ self.s = s
+ self.m1 = m1
+ self.m2 = m2
+ self.m3 = m3
+ self.interclass_filtering_threshold = interclass_filtering_threshold
+
+ # For ArcFace
+ self.cos_m = math.cos(self.m2)
+ self.sin_m = math.sin(self.m2)
+ self.theta = math.cos(math.pi - self.m2)
+ self.sinmm = math.sin(math.pi - self.m2) * self.m2
+ self.easy_margin = False
+
+ def forward(self, logits, labels):
+ index_positive = torch.where(labels != -1)[0]
+
+ if self.interclass_filtering_threshold > 0:
+ with torch.no_grad():
+ dirty = logits > self.interclass_filtering_threshold
+ dirty = dirty.float()
+ mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
+ mask.scatter_(1, labels[index_positive], 0)
+ dirty[index_positive] *= mask
+ tensor_mul = 1 - dirty
+ logits = tensor_mul * logits
+
+ target_logit = logits[index_positive, labels[index_positive].view(-1)]
+
+ if self.m1 == 1.0 and self.m3 == 0.0:
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.m2
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.s
+
+ elif self.m3 > 0:
+ final_target_logit = target_logit - self.m3
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits = logits * self.s
+ else:
+ raise
+
+ return logits
+
+
+class ArcFace(torch.nn.Module):
+ """ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):"""
+
+ def __init__(self, s=64.0, margin=0.5):
+ super(ArcFace, self).__init__()
+ self.scale = s
+ self.margin = margin
+ self.cos_m = math.cos(margin)
+ self.sin_m = math.sin(margin)
+ self.theta = math.cos(math.pi - margin)
+ self.sinmm = math.sin(math.pi - margin) * margin
+ self.easy_margin = False
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.margin
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.s
+ return logits
+
+
+class CosFace(torch.nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+ final_target_logit = target_logit - self.m
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits = logits * self.s
+ return logits
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/lr_scheduler.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3020ff343d3333d18cdf9102d6c66be29bab33fa
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/lr_scheduler.py
@@ -0,0 +1,28 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolyScheduler(_LRScheduler):
+ def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
+ self.base_lr = base_lr
+ self.warmup_lr_init = 0.0001
+ self.max_steps: int = max_steps
+ self.warmup_steps: int = warmup_steps
+ self.power = 2
+ super(PolyScheduler, self).__init__(optimizer, -1, False)
+ self.last_epoch = last_epoch
+
+ def get_warmup_lr(self):
+ alpha = float(self.last_epoch) / float(self.warmup_steps)
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
+
+ def get_lr(self):
+ if self.last_epoch == -1:
+ return [self.warmup_lr_init for _ in self.optimizer.param_groups]
+ if self.last_epoch < self.warmup_steps:
+ return self.get_warmup_lr()
+ else:
+ alpha = pow(
+ 1 - float(self.last_epoch - self.warmup_steps) / float(self.max_steps - self.warmup_steps),
+ self.power,
+ )
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_helper.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..95f615fd7f3e0586be123d9a6538f68386158360
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_helper.py
@@ -0,0 +1,264 @@
+from __future__ import division
+
+import argparse
+import datetime
+import glob
+import os
+import os.path as osp
+import sys
+
+import cv2
+import numpy as np
+import onnx
+import onnxruntime
+from insightface.data import get_image
+from onnx import numpy_helper
+
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ["CPUExecutionProvider"] if cpu else None
+
+ # input_size is (w,h), return error message, return None if success
+ def check(self, track="cfat", test_img=None):
+ # default is cfat
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ if track.startswith("ms1m"):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 10
+ elif track.startswith("glint"):
+ max_model_size_mb = 1024
+ max_feat_dim = 1024
+ max_time_cost = 20
+ elif track.startswith("cfat"):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith("unconstrained"):
+ max_model_size_mb = 1024
+ max_feat_dim = 1024
+ max_time_cost = 30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith(".onnx"):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files) == 0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print("use onnx-model:", self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print("input-shape:", input_shape)
+ if len(input_shape) != 4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ # return "input_shape[0] should be str to support batch-inference"
+ print("reset input-shape[0] to None")
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
+ new_model_file = osp.join(self.model_path, "zzzzrefined.onnx")
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print("use new onnx-model:", self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print("new-input-shape:", input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ # print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ # print(o.name, o.shape)
+ if len(output_names) != 1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ # print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node) < 8:
+ return "too small onnx graph"
+
+ input_size = (112, 112)
+ self.crop = None
+ if track == "cfat":
+ crop_file = osp.join(self.model_path, "crop.txt")
+ if osp.exists(crop_file):
+ lines = open(crop_file, "r").readlines()
+ if len(lines) != 6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size != self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s" % (input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024 * 1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB" % self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track == "cfat":
+ pn_file = osp.join(self.model_path, "pixel_norm.txt")
+ if osp.exists(pn_file):
+ lines = open(pn_file, "r").readlines()
+ if len(lines) != 2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith("Sub") or node.name.startswith("_minus"):
+ find_sub = True
+ if node.name.startswith("Mul") or node.name.startswith("_mul") or node.name.startswith("Div"):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ # mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize < 4:
+ return "invalid weight type - (%s:%s)" % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image("Tom_Hanks_54745")
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float("inf"), -float("inf")] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d" % feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost * 1000
+ if cost_ms > max_time_cost:
+ return "max time cost exceed, given %.4f" % cost_ms
+ self.cost_ms = cost_ms
+ print(
+ "check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f"
+ % (self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)
+ )
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [
+ img,
+ ] * 32
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs,
+ scalefactor=1.0 / self.input_std,
+ size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean),
+ swapRB=True,
+ )
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+ def meta_info(self):
+ return {"model-size-mb": self.model_size_mb, "feature-dim": self.feat_dim, "infer": self.cost_ms}
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != input_size[1] or nimg.shape[1] != input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ imgs, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True
+ )
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != input_size[1] or nimg.shape[1] != input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(
+ img, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True
+ )
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb - ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="")
+ # general
+ parser.add_argument("workdir", help="submitted work dir", type=str)
+ parser.add_argument("--track", help="track name, for different challenge", type=str, default="cfat")
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print("err:", err)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_ijbc.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d1bff03699c96139f3f9e9b52998cba592d9d72
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,262 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+import torch
+from onnx_helper import ArcFaceORT
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+from torch.utils.data import DataLoader
+
+SRC = np.array(
+ [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]],
+ dtype=np.float32,
+)
+SRC[:, 0] += 8.0
+
+
+@torch.no_grad()
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(" ")
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ return torch.from_numpy(output)
+
+
+@torch.no_grad()
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def collate_fn(data):
+ return torch.cat(data, dim=0)
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=128,
+ drop_last=False,
+ num_workers=4,
+ collate_fn=collate_fn,
+ )
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.numpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter : 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=" ", header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True),
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print("Finish Calculating {} template features.".format(count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True # if Ture, TestMode(N1)
+ use_detector_score = True # if Ture, TestMode(D1)
+ use_flip_test = True # if Ture, TestMode(F1)
+ assert args.target == "IJBC" or args.target == "IJBB"
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join("%s/meta" % args.image_path, "%s_face_tid_mid.txt" % args.target.lower())
+ )
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join("%s/meta" % args.image_path, "%s_template_pair_label.txt" % args.target.lower())
+ )
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = "%s/loose_crop" % args.image_path
+ img_list_path = "%s/meta/%s_name_5pts_score.txt" % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+ print("Feature Shape: ({} , {}) .".format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2 :]
+ else:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2]
+
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats**2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+ result_dir = args.model_root
+
+ save_path = os.path.join(result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+ tpr_fpr_table = prettytable.PrettyTable(["Methods"] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="do ijb test")
+ # general
+ parser.add_argument("--model-root", default="", help="path to load model.")
+ parser.add_argument("--image-path", default="/train_tmp/IJB_release/IJBC", type=str, help="")
+ parser.add_argument("--target", default="IJBC", type=str, help="target, set to IJBC or IJBB")
+ main(parser.parse_args())
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7891527d6c396a6b51a67daf06593d4db5cce43
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc.py
@@ -0,0 +1,490 @@
+import collections
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear
+from torch.nn.functional import normalize
+
+
+class PartialFC(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
+
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of
+ negative class centers are selected to compute the margin-based softmax loss, all class
+ centers are still maintained throughout the whole training process, but only a subset is
+ selected and updated in each iteration.
+
+ .. note::
+ When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
+
+ Example:
+ --------
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels, optimizer)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+
+ _version = 1
+
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_mom: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_mom: torch.Tensor
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_mom", tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_mom", tensor=torch.empty(0, 0))
+ self.register_buffer("weight_index", tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ @torch.no_grad()
+ def sample(self, labels: torch.Tensor, index_positive: torch.Tensor, optimizer: torch.optim.Optimizer):
+ """
+ This functions will change the value of labels
+
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ pass
+ index_positive: torch.Tensor
+ pass
+ optimizer: torch.optim.Optimizer
+ pass
+ """
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_mom = self.weight_mom[self.weight_index]
+
+ if isinstance(optimizer, torch.optim.SGD):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated]["momentum_buffer"] = self.weight_activated_mom
+ else:
+ raise
+
+ @torch.no_grad()
+ def update(self):
+ """partial weight to global"""
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_mom[self.weight_index] = self.weight_activated_mom
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size
+ )
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_mom.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_mom.zero_()
+ self.weight_index.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class PartialFCAdamW(torch.nn.Module):
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFCAdamW, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_exp_avg: torch.Tensor
+ self.weight_exp_avg_sq: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_exp_avg: torch.Tensor
+ self.weight_activated_exp_avg_sq: torch.Tensor
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_exp_avg", tensor=torch.zeros_like(self.weight))
+ self.register_buffer("weight_exp_avg_sq", tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_exp_avg", tensor=torch.empty(0, 0))
+ self.register_buffer("weight_activated_exp_avg_sq", tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.step = 0
+
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ @torch.no_grad()
+ def sample(self, labels, index_positive, optimizer):
+ self.step += 1
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
+ self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]
+
+ if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
+ optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
+ optimizer.state[self.weight_activated]["step"] = self.step
+ else:
+ raise
+
+ @torch.no_grad()
+ def update(self):
+ """partial weight to global"""
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
+ self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size
+ )
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_exp_avg.zero_()
+ self.weight_exp_avg_sq.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_exp_avg.zero_()
+ self.weight_activated_exp_avg_sq.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
+ Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_grad (torch.Tensor): gradient backward by last layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc_v2.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..45078e430a6b0cd442ff65618093689822711aef
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/partial_fc_v2.py
@@ -0,0 +1,247 @@
+import math
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear
+from torch.nn.functional import normalize
+
+
+class PartialFC_V2(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of
+ negative class centers are selected to compute the margin-based softmax loss, all class
+ centers are still maintained throughout the whole training process, but only a subset is
+ selected and updated in each iteration.
+ .. note::
+ When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
+ Example:
+ --------
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+
+ _version = 2
+
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC_V2, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+ self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ def sample(self, labels, index_positive):
+ """
+ This functions will change the value of labels
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ pass
+ index_positive: torch.Tensor
+ pass
+ optimizer: torch.optim.Optimizer
+ pass
+ """
+ with torch.no_grad():
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ return self.weight[self.weight_index]
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert (
+ self.last_batch_size == batch_size
+ ), f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}"
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ weight = self.sample(labels, index_positive)
+ else:
+ weight = self.weight
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(weight)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
+ Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_grad (torch.Tensor): gradient backward by last layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/requirement.txt b/Deep3DFaceRecon_pytorch/models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1a431ef9c39b258b676411f1081ed9006a8b817
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/requirement.txt
@@ -0,0 +1,6 @@
+tensorboard
+easydict
+mxnet
+onnx
+sklearn
+opencv-python
\ No newline at end of file
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/run.sh b/Deep3DFaceRecon_pytorch/models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6eacdf8e814d7bd68650c7eda8f72687ee74db16
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/run.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train_v2.py $@
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/scripts/shuffle_rec.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/scripts/shuffle_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..1607fb2db48b9b32f4fa16c6ad97d15582820b2a
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/scripts/shuffle_rec.py
@@ -0,0 +1,81 @@
+import argparse
+import multiprocessing
+import os
+import time
+
+import mxnet as mx
+import numpy as np
+
+
+def read_worker(args, q_in):
+ path_imgidx = os.path.join(args.input, "train.idx")
+ path_imgrec = os.path.join(args.input, "train.rec")
+ imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
+
+ s = imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ assert header.flag > 0
+
+ imgidx = np.array(range(1, int(header.label[0])))
+ np.random.shuffle(imgidx)
+
+ for idx in imgidx:
+ item = imgrec.read_idx(idx)
+ q_in.put(item)
+
+ q_in.put(None)
+ imgrec.close()
+
+
+def write_worker(args, q_out):
+ pre_time = time.time()
+
+ if args.input[-1] == "/":
+ args.input = args.input[:-1]
+ dirname = os.path.dirname(args.input)
+ basename = os.path.basename(args.input)
+ output = os.path.join(dirname, f"shuffled_{basename}")
+ os.makedirs(output, exist_ok=True)
+
+ path_imgidx = os.path.join(output, "train.idx")
+ path_imgrec = os.path.join(output, "train.rec")
+ save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
+ more = True
+ count = 0
+ while more:
+ deq = q_out.get()
+ if deq is None:
+ more = False
+ else:
+ header, jpeg = mx.recordio.unpack(deq)
+ # TODO it is currently not fully developed
+ if isinstance(header.label, float):
+ label = header.label
+ else:
+ label = header.label[0]
+
+ header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
+ save_record.write_idx(count, mx.recordio.pack(header, jpeg))
+ count += 1
+ if count % 10000 == 0:
+ cur_time = time.time()
+ print("save time:", cur_time - pre_time, " count:", count)
+ pre_time = cur_time
+ print(count)
+ save_record.close()
+
+
+def main(args):
+ queue = multiprocessing.Queue(10240)
+ read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
+ read_process.daemon = True
+ read_process.start()
+ write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
+ write_process.start()
+ write_process.join()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input", help="path to source rec.")
+ main(parser.parse_args())
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/torch2onnx.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c2bb9e85c9bc5dc0b90842ad9c782d5e7cde79
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/torch2onnx.py
@@ -0,0 +1,56 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float)
+ img = (img / 255.0 - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight, strict=True)
+ net.eval()
+ torch.onnx.export(
+ net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset
+ )
+ model = onnx.load(output)
+ graph = model.graph
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
+ if simplify:
+ from onnxsim import simplify
+
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == "__main__":
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description="ArcFace PyTorch to onnx")
+ parser.add_argument("input", type=str, help="input backbone.pth file or path")
+ parser.add_argument("--output", type=str, default=None, help="output onnx path")
+ parser.add_argument("--network", type=str, default=None, help="backbone network")
+ parser.add_argument("--simplify", type=bool, default=False, help="onnx simplify")
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "model.pt")
+ assert os.path.exists(input_file)
+ # model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ # params = model_name.split("_")
+ # if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ # if args.network is None:
+ # args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512)
+ if args.output is None:
+ args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
+ convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/train.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..3905bb0f90bb3806cd0698a322617a0d47c61390
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/train.py
@@ -0,0 +1,253 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc import PartialFC
+from partial_fc import PartialFCAdamW
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging
+from utils.utils_callbacks import CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter
+from utils.utils_logging import init_logging
+
+assert (
+ torch.__version__ >= "1.12.0"
+), "In order to enjoy the features of the new torch, \
+we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
+
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) if rank == 0 else None
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = (
+ wandb.init(
+ entity=cfg.wandb_entity,
+ project=cfg.wandb_project,
+ sync_tensorboard=True,
+ resume=cfg.wandb_resume,
+ name=run_name,
+ notes=cfg.notes,
+ )
+ if rank == 0 or cfg.wandb_log_all
+ else None
+ )
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(cfg.rec, local_rank, cfg.batch_size, cfg.dali, cfg.seed, cfg.num_workers)
+
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, find_unused_parameters=True
+ )
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64, cfg.margin_list[0], cfg.margin_list[1], cfg.margin_list[2], cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ momentum=0.9,
+ weight_decay=cfg.weight_decay,
+ )
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFCAdamW(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ )
+ else:
+ raise
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt, base_lr=cfg.lr, max_steps=cfg.total_step, warmup_steps=cfg.warmup_step, last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer, wandb_logger=wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step=global_step,
+ writer=summary_writer,
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
+
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ else:
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log(
+ {
+ "Loss/Step Loss": loss.item(),
+ "Loss/Train Loss": loss_am.avg,
+ "Process/Step": global_step,
+ "Process/Epoch": epoch,
+ }
+ )
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ from torch2onnx import convert_onnx
+
+ convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/train_v2.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/train_v2.py
new file mode 100755
index 0000000000000000000000000000000000000000..ba3c15e6a1615f28daaab1ad225f7b61b27bdffc
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/train_v2.py
@@ -0,0 +1,248 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc_v2 import PartialFC_V2
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging
+from utils.utils_callbacks import CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter
+from utils.utils_logging import init_logging
+
+assert (
+ torch.__version__ >= "1.12.0"
+), "In order to enjoy the features of the new torch, \
+we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
+
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) if rank == 0 else None
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = (
+ wandb.init(
+ entity=cfg.wandb_entity,
+ project=cfg.wandb_project,
+ sync_tensorboard=True,
+ resume=cfg.wandb_resume,
+ name=run_name,
+ notes=cfg.notes,
+ )
+ if rank == 0 or cfg.wandb_log_all
+ else None
+ )
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(cfg.rec, local_rank, cfg.batch_size, cfg.dali, cfg.seed, cfg.num_workers)
+
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, find_unused_parameters=True
+ )
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64, cfg.margin_list[0], cfg.margin_list[1], cfg.margin_list[2], cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC_V2(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ momentum=0.9,
+ weight_decay=cfg.weight_decay,
+ )
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFC_V2(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ )
+ else:
+ raise
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt, base_lr=cfg.lr, max_steps=cfg.total_step, warmup_steps=cfg.warmup_step, last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer, wandb_logger=wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step=global_step,
+ writer=summary_writer,
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels)
+
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ if global_step % cfg.gradient_acc == 0:
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ opt.zero_grad()
+ else:
+ loss.backward()
+ if global_step % cfg.gradient_acc == 0:
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log(
+ {
+ "Loss/Step Loss": loss.item(),
+ "Loss/Train Loss": loss_am.avg,
+ "Process/Step": global_step,
+ "Process/Epoch": epoch,
+ }
+ )
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/plot.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1429e77d67c32ce9f6c4495c75608941bbcebc
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/plot.py
@@ -0,0 +1,65 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import auc
+from sklearn.metrics import roc_curve
+
+with open(sys.argv[1], "r") as f:
+ files = f.readlines()
+
+files = [x.strip() for x in files]
+image_path = "/train_tmp/IJB_release/IJBC"
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % "ijbc"))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2")))
+x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(
+ fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100))
+ )
+ tpr_fpr_row = []
+ tpr_fpr_row.append(method)
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10**-6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle="--", linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale("log")
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.title("ROC on IJB")
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_callbacks.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_callbacks.py
new file mode 100755
index 0000000000000000000000000000000000000000..6afa461dd3a163628f71499da66fc032b272e969
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,141 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+from eval import verification
+from torch import distributed
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+ def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None):
+ self.rank: int = distributed.get_rank()
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank is 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ self.summary_writer = summary_writer
+ self.wandb_logger = wandb_logger
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(self.ver_list[i], backbone, 10, 10)
+ logging.info("[%s][%d]XNorm: %f" % (self.ver_name_list[i], global_step, xnorm))
+ logging.info("[%s][%d]Accuracy-Flip: %1.5f+-%1.5f" % (self.ver_name_list[i], global_step, acc2, std2))
+
+ self.summary_writer: SummaryWriter
+ self.summary_writer.add_scalar(
+ tag=self.ver_name_list[i],
+ scalar_value=acc2,
+ global_step=global_step,
+ )
+ if self.wandb_logger:
+ import wandb
+
+ self.wandb_logger.log(
+ {
+ f"Acc/val-Acc1 {self.ver_name_list[i]}": acc1,
+ f"Acc/val-Acc2 {self.ver_name_list[i]}": acc2,
+ # f'Acc/val-std1 {self.ver_name_list[i]}': std1,
+ # f'Acc/val-std2 {self.ver_name_list[i]}': acc2,
+ }
+ )
+
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ "[%s][%d]Accuracy-Highest: %1.5f" % (self.ver_name_list[i], global_step, self.highest_acc_list[i])
+ )
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank is 0 and num_update > 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None):
+ self.frequent: int = frequent
+ self.rank: int = distributed.get_rank()
+ self.world_size: int = distributed.get_world_size()
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.start_step: int = start_step
+ self.batch_size: int = batch_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(
+ self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler,
+ ):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float("inf")
+
+ # time_now = (time.time() - self.time_start) / 3600
+ # time_total = time_now / ((global_step + 1) / self.total_step)
+ # time_for_end = time_total - time_now
+ time_now = time.time()
+ time_sec = int(time_now - self.time_start)
+ time_sec_avg = time_sec / (global_step - self.start_step + 1)
+ eta_sec = time_sec_avg * (self.total_step - global_step - 1)
+ time_for_end = eta_sec / 3600
+ if self.writer is not None:
+ self.writer.add_scalar("time_for_end", time_for_end, global_step)
+ self.writer.add_scalar("learning_rate", learning_rate, global_step)
+ self.writer.add_scalar("loss", loss.avg, global_step)
+ if fp16:
+ msg = (
+ "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d "
+ "Fp16 Grad Scale: %2.f Required: %1.f hours"
+ % (
+ speed_total,
+ loss.avg,
+ learning_rate,
+ epoch,
+ global_step,
+ grad_scaler.get_scale(),
+ time_for_end,
+ )
+ )
+ else:
+ msg = (
+ "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d "
+ "Required: %1.f hours"
+ % (speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end)
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_config.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..140625ccfbc1b4b8d71470f50da7d4f88803cf11
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith("configs/"), "config file setting must start with configs/"
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join("work_dirs", temp_module_name)
+ return cfg
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_distributed_sampler.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e57275fa17a0a9dbf27fd0eb941dd0fec1823f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_distributed_sampler.py
@@ -0,0 +1,124 @@
+import math
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+def setup_seed(seed, cuda_deterministic=True):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ if cuda_deterministic: # slower, more reproducible
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ else: # faster, less reproducible
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = True
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # The seed of each worker equals to
+ # num_worker * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+
+ return rank, world_size
+
+
+def sync_random_seed(seed=None, device="cuda"):
+ """Make sure different ranks share the same seed.
+ All workers must call this function, otherwise it will deadlock.
+ This method is generally used in `DistributedSampler`,
+ because the seed should be identical across all processes
+ in the distributed group.
+ In distributed sampling, different ranks should sample non-overlapped
+ data in the dataset. Therefore, this function is used to make sure that
+ each rank shuffles the data indices in the same order based
+ on the same seed. Then different ranks could use different indices
+ to select non-overlapped data from the same data list.
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is None:
+ seed = np.random.randint(2**31)
+ assert isinstance(seed, int)
+
+ rank, world_size = get_dist_info()
+
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+
+ dist.broadcast(random_num, src=0)
+
+ return random_num.item()
+
+
+class DistributedSampler(_DistributedSampler):
+ def __init__(
+ self,
+ dataset,
+ num_replicas=None, # world_size
+ rank=None, # local_rank
+ shuffle=True,
+ seed=0,
+ ):
+
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ self.seed = sync_random_seed(seed)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ # When :attr:`shuffle=True`, this ensures all replicas
+ # use a different random ordering for each epoch.
+ # Otherwise, the next iteration of this sampler will
+ # yield the same ordering.
+ g.manual_seed(self.epoch + self.seed)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ # in case that indices is shorter than half of total_size
+ indices = (indices * math.ceil(self.total_size / len(indices)))[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
diff --git a/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_logging.py b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..823771b7d7c45fd30fe7d5284cb52ee6ad17c834
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,40 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info("rank_id: %d" % rank)
diff --git a/Deep3DFaceRecon_pytorch/models/base_model.py b/Deep3DFaceRecon_pytorch/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..36473ba5b2333eb4cdf66950bfaafa790fa40519
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/base_model.py
@@ -0,0 +1,321 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+import os
+from abc import ABC
+from abc import abstractmethod
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate losses, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this fucntion, you should first call
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): specify the images that you want to display and save.
+ -- self.visual_names (str list): define networks used in our training.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.isTrain = opt.isTrain
+ self.device = torch.device("cpu")
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.parallel_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+
+ return grad_hook
+
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+ # self.print_networks(opt.verbose)
+
+ def parallelize(self, convert_sync_batchnorm=True):
+ if not self.opt.use_ddp:
+ for name in self.parallel_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+ else:
+ for name in self.model_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ if convert_sync_batchnorm:
+ module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+ setattr(
+ self,
+ name,
+ torch.nn.parallel.DistributedDataParallel(
+ module.to(self.device),
+ device_ids=[self.device.index],
+ find_unused_parameters=True,
+ broadcast_buffers=True,
+ ),
+ )
+
+ # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+ for name in self.parallel_names:
+ if isinstance(name, str) and name not in self.model_names:
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+
+ # put state_dict of optimizer to gpu device
+ if self.opt.phase != "test":
+ if self.opt.continue_train:
+ for optim in self.optimizers:
+ for state in optim.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(self.device)
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def train(self):
+ """Make models train mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.train()
+
+ def eval(self):
+ """Make models eval mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps function in no_grad() so we don't save intermediate steps for backprop
+ It also calls to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self, name="A"):
+ """Return image paths that are used to load current data"""
+ return self.image_paths if name == "A" else self.image_paths_B
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == "plateau":
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]["lr"]
+ print("learning rate = %.7f" % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)[:, :3, ...]
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(
+ getattr(self, "loss_" + name)
+ ) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if not os.path.isdir(self.save_dir):
+ os.makedirs(self.save_dir)
+
+ save_filename = "epoch_%s.pth" % (epoch)
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ save_dict = {}
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
+ net = net.module
+ save_dict[name] = net.state_dict()
+
+ for i, optim in enumerate(self.optimizers):
+ save_dict["opt_%02d" % i] = optim.state_dict()
+
+ for i, sched in enumerate(self.schedulers):
+ save_dict["sched_%02d" % i] = sched.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith("InstanceNorm") and (key == "running_mean" or key == "running_var"):
+ if getattr(module, key) is None:
+ state_dict.pop(".".join(keys))
+ if module.__class__.__name__.startswith("InstanceNorm") and (key == "num_batches_tracked"):
+ state_dict.pop(".".join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+ load_filename = "epoch_%s.pth" % (epoch)
+ load_path = os.path.join(load_dir, load_filename)
+ state_dict = torch.load(load_path, map_location=self.device)
+ print("loading the model from %s" % load_path)
+
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ net.load_state_dict(state_dict[name])
+
+ if self.opt.phase != "test":
+ if self.opt.continue_train:
+ print("loading the optim from %s" % load_path)
+ for i, optim in enumerate(self.optimizers):
+ optim.load_state_dict(state_dict["opt_%02d" % i])
+
+ try:
+ print("loading the sched from %s" % load_path)
+ for i, sched in enumerate(self.schedulers):
+ sched.load_state_dict(state_dict["sched_%02d" % i])
+ except:
+ print("Failed to load schedulers, set schedulers according to epoch count manually")
+ for i, sched in enumerate(self.schedulers):
+ sched.last_epoch = self.opt.epoch_count - 1
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print("---------- Networks initialized -------------")
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print("[Network %s] Total number of parameters : %.3f M" % (name, num_params / 1e6))
+ print("-----------------------------------------------")
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
diff --git a/Deep3DFaceRecon_pytorch/models/bfm.py b/Deep3DFaceRecon_pytorch/models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..82fc25090f5997bb6f297cd0a8b4fae3ee99f715
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/bfm.py
@@ -0,0 +1,291 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+
+from Deep3DFaceRecon_pytorch.util.load_mats import transferBFM09
+
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([focal, 0, center, 0, focal, center, 0, 0, 1]).reshape([3, 3]).astype(np.float32).transpose()
+
+
+class SH:
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.0), 2 * np.pi / np.sqrt(8.0)]
+ self.c = [1 / np.sqrt(4 * np.pi), np.sqrt(3.0) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)]
+
+
+class ParametricFaceModel:
+ def __init__(
+ self,
+ bfm_folder="./BFM",
+ recenter=True,
+ camera_distance=10.0,
+ init_lit=np.array([0.8, 0, 0, 0, 0, 0, 0, 0, 0]),
+ focal=1015.0,
+ center=112.0,
+ is_train=True,
+ default_name="BFM_model_front.mat",
+ ):
+
+ if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+ transferBFM09(bfm_folder)
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model["meanshape"].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model["idBase"].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model["exBase"].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model["meantex"].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model["texBase"].astype(np.float32)
+ # face indices for each vertex that lies in. starts from 0. [N,8]
+ self.point_buf = model["point_buf"].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model["tri"].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ self.keypoints = np.squeeze(model["keypoints"]).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model["frontmask2_idx"]).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model["tri_mask2"].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model["skinmask"])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = "cpu"
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+ elif type(value).__module__ == torch.__name__ and isinstance(value, torch.Tensor):
+ setattr(self, key, value.to(device))
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum("ij,aj->ai", self.id_base, id_coeff)
+ exp_part = torch.einsum("ij,aj->ai", self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum("ij,aj->ai", self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.0
+ return face_texture.reshape([batch_size, -1, 3])
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
+ Y = torch.cat(
+ [
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.0) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2),
+ ],
+ dim=-1,
+ )
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+ def compute_rotation(self, angles):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ ones = torch.ones([batch_size, 1]).to(self.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.device)
+ x, y, z = (
+ angles[:, :1],
+ angles[:, 1:2],
+ angles[:, 2:],
+ )
+
+ rot_x = torch.cat(
+ [ones, zeros, zeros, zeros, torch.cos(x), -torch.sin(x), zeros, torch.sin(x), torch.cos(x)], dim=1
+ ).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat(
+ [torch.cos(y), zeros, torch.sin(y), zeros, ones, zeros, -torch.sin(y), zeros, torch.cos(y)], dim=1
+ ).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat(
+ [torch.cos(z), -torch.sin(z), zeros, torch.sin(z), torch.cos(z), zeros, zeros, zeros, ones], dim=1
+ ).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+ def to_camera(self, face_shape):
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80:144]
+ tex_coeffs = coeffs[:, 144:224]
+ angles = coeffs[:, 224:227]
+ gammas = coeffs[:, 227:254]
+ translations = coeffs[:, 254:]
+ return {
+ "id": id_coeffs,
+ "exp": exp_coeffs,
+ "tex": tex_coeffs,
+ "angle": angles,
+ "gamma": gammas,
+ "trans": translations,
+ }
+
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict["id"], coef_dict["exp"])
+ rotation = self.compute_rotation(coef_dict["angle"])
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict["trans"])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict["tex"])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict["gamma"])
+
+ return face_vertex, face_texture, face_color, landmark
diff --git a/Deep3DFaceRecon_pytorch/models/facerecon_model.py b/Deep3DFaceRecon_pytorch/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..252b4eda2eb8b8098b22da72798a9843dc920b7a
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/facerecon_model.py
@@ -0,0 +1,268 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+import numpy as np
+import torch
+import trimesh
+from scipy.io import savemat
+from util import util
+from util.nvdiffrast import MeshRenderer
+from util.preprocess import estimate_norm_torch
+
+from . import networks
+from .base_model import BaseModel
+from .bfm import ParametricFaceModel
+from .losses import landmark_loss
+from .losses import perceptual_loss
+from .losses import photo_loss
+from .losses import reflectance_loss
+from .losses import reg_loss
+
+
+class FaceReconModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Configures options specific for CUT model"""
+ # net structure and parameters
+ parser.add_argument(
+ "--net_recon",
+ type=str,
+ default="resnet50",
+ choices=["resnet18", "resnet34", "resnet50"],
+ help="network structure",
+ )
+ parser.add_argument("--init_path", type=str, default="checkpoints/init_model/resnet50-0676ba61.pth")
+ parser.add_argument(
+ "--use_last_fc",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="zero initialize the last fc",
+ )
+ parser.add_argument("--bfm_folder", type=str, default="BFM")
+ parser.add_argument("--bfm_model", type=str, default="BFM_model_front.mat", help="bfm model")
+
+ # renderer parameters
+ parser.add_argument("--focal", type=float, default=1015.0)
+ parser.add_argument("--center", type=float, default=112.0)
+ parser.add_argument("--camera_d", type=float, default=10.0)
+ parser.add_argument("--z_near", type=float, default=5.0)
+ parser.add_argument("--z_far", type=float, default=15.0)
+ parser.add_argument(
+ "--use_opengl", type=util.str2bool, nargs="?", const=True, default=True, help="use opengl context or not"
+ )
+
+ if is_train:
+ # training parameters
+ parser.add_argument(
+ "--net_recog",
+ type=str,
+ default="r50",
+ choices=["r18", "r43", "r50"],
+ help="face recog network structure",
+ )
+ parser.add_argument(
+ "--net_recog_path", type=str, default="checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth"
+ )
+ parser.add_argument(
+ "--use_crop_face",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="use crop mask for photo loss",
+ )
+ parser.add_argument(
+ "--use_predef_M",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="use predefined M for predicted face",
+ )
+
+ # augmentation parameters
+ parser.add_argument("--shift_pixs", type=float, default=10.0, help="shift pixels")
+ parser.add_argument("--scale_delta", type=float, default=0.1, help="delta scale factor")
+ parser.add_argument("--rot_angle", type=float, default=10.0, help="rot angles, degree")
+
+ # loss weights
+ parser.add_argument("--w_feat", type=float, default=0.2, help="weight for feat loss")
+ parser.add_argument("--w_color", type=float, default=1.92, help="weight for loss loss")
+ parser.add_argument("--w_reg", type=float, default=3.0e-4, help="weight for reg loss")
+ parser.add_argument("--w_id", type=float, default=1.0, help="weight for id_reg loss")
+ parser.add_argument("--w_exp", type=float, default=0.8, help="weight for exp_reg loss")
+ parser.add_argument("--w_tex", type=float, default=1.7e-2, help="weight for tex_reg loss")
+ parser.add_argument("--w_gamma", type=float, default=10.0, help="weight for gamma loss")
+ parser.add_argument("--w_lm", type=float, default=1.6e-3, help="weight for lm loss")
+ parser.add_argument("--w_reflc", type=float, default=5.0, help="weight for reflc loss")
+
+ opt, _ = parser.parse_known_args()
+ parser.set_defaults(focal=1015.0, center=112.0, camera_d=10.0, use_last_fc=False, z_near=5.0, z_far=15.0)
+ if is_train:
+ parser.set_defaults(use_crop_face=True, use_predef_M=False)
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+
+ self.visual_names = ["output_vis"]
+ self.model_names = ["net_recon"]
+ self.parallel_names = self.model_names + ["renderer"]
+
+ self.net_recon = networks.define_net_recon(
+ net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
+ )
+
+ self.facemodel = ParametricFaceModel(
+ bfm_folder=opt.bfm_folder,
+ camera_distance=opt.camera_d,
+ focal=opt.focal,
+ center=opt.center,
+ is_train=self.isTrain,
+ default_name=opt.bfm_model,
+ )
+
+ fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+ self.renderer = MeshRenderer(
+ rasterize_fov=fov,
+ znear=opt.z_near,
+ zfar=opt.z_far,
+ rasterize_size=int(2 * opt.center),
+ use_opengl=opt.use_opengl,
+ )
+
+ if self.isTrain:
+ self.loss_names = ["all", "feat", "color", "lm", "reg", "gamma", "reflc"]
+
+ self.net_recog = networks.define_net_recog(net_recog=opt.net_recog, pretrained_path=opt.net_recog_path)
+ # loss func name: (compute_%s_loss) % loss_name
+ self.compute_feat_loss = perceptual_loss
+ self.comupte_color_loss = photo_loss
+ self.compute_lm_loss = landmark_loss
+ self.compute_reg_loss = reg_loss
+ self.compute_reflc_loss = reflectance_loss
+
+ self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+ self.optimizers = [self.optimizer]
+ self.parallel_names += ["net_recog"]
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ self.input_img = input["imgs"].to(self.device)
+ self.atten_mask = input["msks"].to(self.device) if "msks" in input else None
+ self.gt_lm = input["lms"].to(self.device) if "lms" in input else None
+ self.trans_m = input["M"].to(self.device) if "M" in input else None
+ self.image_paths = input["im_paths"] if "im_paths" in input else None
+
+ def forward(self):
+ output_coeff = self.net_recon(self.input_img)
+ self.facemodel.to(self.device)
+ self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = self.facemodel.compute_for_render(output_coeff)
+ self.pred_mask, _, self.pred_face = self.renderer(
+ self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color
+ )
+
+ self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+ def compute_losses(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+
+ assert self.net_recog.training == False
+ trans_m = self.trans_m
+ if not self.opt.use_predef_M:
+ trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+ pred_feat = self.net_recog(self.pred_face, trans_m)
+ gt_feat = self.net_recog(self.input_img, self.trans_m)
+ self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+ face_mask = self.pred_mask
+ if self.opt.use_crop_face:
+ face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+ face_mask = face_mask.detach()
+ self.loss_color = self.opt.w_color * self.comupte_color_loss(
+ self.pred_face, self.input_img, self.atten_mask * face_mask
+ )
+
+ loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+ self.loss_reg = self.opt.w_reg * loss_reg
+ self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+ self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+ self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+ self.loss_all = (
+ self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma + self.loss_lm + self.loss_reflc
+ )
+
+ def optimize_parameters(self, isTrain=True):
+ self.forward()
+ self.compute_losses()
+ """Update network weights; it will be called in every training iteration."""
+ if isTrain:
+ self.optimizer.zero_grad()
+ self.loss_all.backward()
+ self.optimizer.step()
+
+ def compute_visuals(self):
+ with torch.no_grad():
+ input_img_numpy = 255.0 * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+ output_vis_numpy_raw = 255.0 * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+ if self.gt_lm is not None:
+ gt_lm_numpy = self.gt_lm.cpu().numpy()
+ pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, "b")
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, "r")
+
+ output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw, output_vis_numpy), axis=-2)
+ else:
+ output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw), axis=-2)
+
+ self.output_vis = (
+ torch.tensor(output_vis_numpy / 255.0, dtype=torch.float32).permute(0, 3, 1, 2).to(self.device)
+ )
+
+ def save_mesh(self, name):
+
+ recon_shape = self.pred_vertex # get reconstructed shape
+ recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+ recon_shape = recon_shape.cpu().numpy()[0]
+ recon_color = self.pred_color
+ recon_color = recon_color.cpu().numpy()[0]
+ tri = self.facemodel.face_buf.cpu().numpy()
+ mesh = trimesh.Trimesh(
+ vertices=recon_shape,
+ faces=tri,
+ vertex_colors=np.clip(255.0 * recon_color, 0, 255).astype(np.uint8),
+ process=False,
+ )
+ mesh.export(name)
+
+ def save_coeff(self, name):
+
+ pred_coeffs = {key: self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+ pred_lm = self.pred_lm.cpu().numpy()
+ pred_lm = np.stack(
+ [pred_lm[:, :, 0], self.input_img.shape[2] - 1 - pred_lm[:, :, 1]], axis=2
+ ) # transfer to image coordinate
+ pred_coeffs["lm68"] = pred_lm
+ savemat(name, pred_coeffs)
diff --git a/Deep3DFaceRecon_pytorch/models/losses.py b/Deep3DFaceRecon_pytorch/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3cb3fd778adedf31acbc3ff01018e9efb99d65b
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/losses.py
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from kornia.geometry import warp_affine
+
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize))
+
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+ def __init__(self, recog_net, input_size=112):
+ super(PerceptualLoss, self).__init__()
+ self.recog_net = recog_net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size = input_size
+
+ def forward(imageA, imageB, M):
+ """
+ 1 - cosine distance
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+ imageB --same as imageA
+ """
+
+ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+ imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+ # freeze bn
+ self.recog_net.eval()
+
+ id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+ id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+
+def perceptual_loss(id_featureA, id_featureB):
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+ """
+ l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur)
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+ imageB --same as imageA
+ """
+ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+ return loss
+
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+ """
+ weighted mse loss
+ Parameters:
+ predict_lm --torch.tensor (B, 68, 2)
+ gt_lm --torch.tensor (B, 68, 2)
+ weight --numpy.array (1, 68)
+ """
+ if not weight:
+ weight = np.ones([68])
+ weight[28:31] = 20
+ weight[-8:] = 20
+ weight = np.expand_dims(weight, 0)
+ weight = torch.tensor(weight).to(predict_lm.device)
+ loss = torch.sum((predict_lm - gt_lm) ** 2, dim=-1) * weight
+ loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+ return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+ """
+ l2 norm without the sqrt, from yu's implementation (mse)
+ tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+ Parameters:
+ coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+ """
+ # coefficient regularization to ensure plausible 3d faces
+ if opt:
+ w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+ else:
+ w_id, w_exp, w_tex = 1, 1, 1, 1
+ creg_loss = (
+ w_id * torch.sum(coeffs_dict["id"] ** 2)
+ + w_exp * torch.sum(coeffs_dict["exp"] ** 2)
+ + w_tex * torch.sum(coeffs_dict["tex"] ** 2)
+ )
+ creg_loss = creg_loss / coeffs_dict["id"].shape[0]
+
+ # gamma regularization to ensure a nearly-monochromatic light
+ gamma = coeffs_dict["gamma"].reshape([-1, 3, 9])
+ gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+ gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+ return creg_loss, gamma_loss
+
+
+def reflectance_loss(texture, mask):
+ """
+ minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo
+ Parameters:
+ texture --torch.tensor, (B, N, 3)
+ mask --torch.tensor, (N), 1 or 0
+
+ """
+ mask = mask.reshape([1, mask.shape[0], 1])
+ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+ loss = torch.sum(((texture - texture_mean) * mask) ** 2) / (texture.shape[0] * torch.sum(mask))
+ return loss
diff --git a/Deep3DFaceRecon_pytorch/models/networks.py b/Deep3DFaceRecon_pytorch/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..610f146f683849640aba3c3eaff0c14beade6732
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/networks.py
@@ -0,0 +1,547 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+import functools
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import init
+from torch.optim import lr_scheduler
+
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize))
+
+
+def filter_state_dict(state_dict, remove_name="fc"):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == "linear":
+
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+ return lr_l
+
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == "step":
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+ elif opt.lr_policy == "plateau":
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == "cosine":
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ return NotImplementedError("learning rate policy [%s] is not implemented", opt.lr_policy)
+ return scheduler
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+ return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+
+def define_net_recog(net_recog, pretrained_path=None):
+ net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+ net.eval()
+ return net
+
+
+class ReconNetWrapper(nn.Module):
+ fc_dim = 257
+
+ def __init__(self, net_recon, use_last_fc=False, init_path=None):
+ super(ReconNetWrapper, self).__init__()
+ self.use_last_fc = use_last_fc
+ if net_recon not in func_dict:
+ return NotImplementedError("network [%s] is not implemented", net_recon)
+ func, last_dim = func_dict[net_recon]
+ backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+ if init_path and os.path.isfile(init_path):
+ state_dict = filter_state_dict(torch.load(init_path, map_location="cpu"))
+ backbone.load_state_dict(state_dict)
+ print("loading init net_recon %s from %s" % (net_recon, init_path))
+ self.backbone = backbone
+ if not use_last_fc:
+ self.final_layers = nn.ModuleList(
+ [
+ conv1x1(last_dim, 80, bias=True), # id layer
+ conv1x1(last_dim, 64, bias=True), # exp layer
+ conv1x1(last_dim, 80, bias=True), # tex layer
+ conv1x1(last_dim, 3, bias=True), # angle layer
+ conv1x1(last_dim, 27, bias=True), # gamma layer
+ conv1x1(last_dim, 2, bias=True), # tx, ty
+ conv1x1(last_dim, 1, bias=True), # tz
+ ]
+ )
+ for m in self.final_layers:
+ nn.init.constant_(m.weight, 0.0)
+ nn.init.constant_(m.bias, 0.0)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if not self.use_last_fc:
+ output = []
+ for layer in self.final_layers:
+ output.append(layer(x))
+ x = torch.flatten(torch.cat(output, dim=1), 1)
+ return x
+
+
+class RecogNetWrapper(nn.Module):
+ def __init__(self, net_recog, pretrained_path=None, input_size=112):
+ super(RecogNetWrapper, self).__init__()
+ net = get_model(name=net_recog, fp16=False)
+ if pretrained_path:
+ state_dict = torch.load(pretrained_path, map_location="cpu")
+ net.load_state_dict(state_dict)
+ print("loading pretrained net_recog %s from %s" % (net_recog, pretrained_path))
+ for param in net.parameters():
+ param.requires_grad = False
+ self.net = net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size = input_size
+
+ def forward(self, image, M):
+ image = self.preprocess(resize_n_crop(image, M, self.input_size))
+ id_feature = F.normalize(self.net(image), dim=-1, p=2)
+ return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = [
+ "ResNet",
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "resnet152",
+ "resnext50_32x4d",
+ "resnext101_32x8d",
+ "wide_resnet50_2",
+ "wide_resnet101_2",
+]
+
+
+model_urls = {
+ "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
+ "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
+ "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
+ "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+ "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+ "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.0)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ use_last_fc: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+ )
+ self.use_last_fc = use_last_fc
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ if self.use_last_fc:
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ if self.use_last_fc:
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs["groups"] = 32
+ kwargs["width_per_group"] = 4
+ return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs["groups"] = 32
+ kwargs["width_per_group"] = 8
+ return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs["width_per_group"] = 64 * 2
+ return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs["width_per_group"] = 64 * 2
+ return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+func_dict = {"resnet18": (resnet18, 512), "resnet50": (resnet50, 2048)}
diff --git a/Deep3DFaceRecon_pytorch/models/template_model.py b/Deep3DFaceRecon_pytorch/models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..650d38cdc1bf61ed481c9d25fb709dc46483f757
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/models/template_model.py
@@ -0,0 +1,105 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+ min_ ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+ : Add model-specific options and rewrite default values for existing options.
+ <__init__>: Initialize this model class.
+ : Unpack input data and perform data pre-processing.
+ : Run forward pass. This will be called by both and .
+ : Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+
+from . import networks
+from .base_model import BaseModel
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(
+ dataset_mode="aligned"
+ ) # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument(
+ "--lambda_regression", type=float, default=1.0, help="weight for the regression loss"
+ ) # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ["loss_G"]
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ["data_A", "data_B", "output"]
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ["G"]
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == "AtoB" # use to swap data_A and data_B
+ self.data_A = input["A" if AtoB else "B"].to(self.device) # get image data A
+ self.data_B = input["B" if AtoB else "A"].to(self.device) # get image data B
+ self.image_paths = input["A_paths" if AtoB else "B_paths"] # get image paths
+
+ def forward(self):
+ """Run forward pass. This will be called by both functions and ."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ # caculate the intermediate results if necessary; here self.output has been computed during function
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
diff --git a/Deep3DFaceRecon_pytorch/options/__init__.py b/Deep3DFaceRecon_pytorch/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/Deep3DFaceRecon_pytorch/options/base_options.py b/Deep3DFaceRecon_pytorch/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8a723d26ef966ae097da3e53a2a27616258425
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/options/base_options.py
@@ -0,0 +1,203 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+import argparse
+import os
+
+import data
+import numpy as np
+import torch
+from util import util
+
+import models
+
+
+class BaseOptions:
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self, cmd_line=None):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="face_recon",
+ help="name of the experiment. It decides where to store samples and models",
+ )
+ parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU")
+ parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")
+ parser.add_argument("--vis_batch_nums", type=float, default=1, help="batch nums of images for visulization")
+ parser.add_argument(
+ "--eval_batch_nums", type=float, default=float("inf"), help="batch nums of images for evaluation"
+ )
+ parser.add_argument(
+ "--use_ddp",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="whether use distributed data parallel",
+ )
+ parser.add_argument("--ddp_port", type=str, default="12355", help="ddp port")
+ parser.add_argument(
+ "--display_per_batch",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="whether use batch to show losses",
+ )
+ parser.add_argument(
+ "--add_image",
+ type=util.str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="whether add image to tensorboard",
+ )
+ parser.add_argument("--world_size", type=int, default=1, help="batch nums of images for evaluation")
+
+ # model parameters
+ parser.add_argument("--model", type=str, default="facerecon", help="chooses which model to use.")
+
+ # additional parameters
+ parser.add_argument(
+ "--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model"
+ )
+ parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
+ parser.add_argument(
+ "--suffix",
+ default="",
+ type=str,
+ help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}",
+ )
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # set cuda visible devices
+ os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ if opt.dataset_mode:
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ""
+ message += "----------------- Options ---------------\n"
+ for k, v in sorted(vars(opt).items()):
+ comment = ""
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = "\t[default: %s]" % str(default)
+ message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
+ message += "----------------- End -------------------"
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, "{}_opt.txt".format(opt.phase))
+ try:
+ with open(file_name, "wt") as opt_file:
+ opt_file.write(message)
+ opt_file.write("\n")
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ("_" + opt.suffix.format(**vars(opt))) if opt.suffix != "" else ""
+ opt.name = opt.name + suffix
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(",")
+ gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ gpu_ids.append(id)
+ opt.world_size = len(gpu_ids)
+ # if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(gpu_ids[0])
+ if opt.world_size == 1:
+ opt.use_ddp = False
+
+ if opt.phase != "test":
+ # set continue_train automatically
+ if opt.pretrained_name is None:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ else:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+ if os.path.isdir(model_dir):
+ model_pths = [i for i in os.listdir(model_dir) if i.endswith("pth")]
+ if os.path.isdir(model_dir) and len(model_pths) != 0:
+ opt.continue_train = True
+
+ # update the latest epoch count
+ if opt.continue_train:
+ if opt.epoch == "latest":
+ epoch_counts = [int(i.split(".")[0].split("_")[-1]) for i in model_pths if "latest" not in i]
+ if len(epoch_counts) != 0:
+ opt.epoch_count = max(epoch_counts) + 1
+ else:
+ opt.epoch_count = int(opt.epoch) + 1
+
+ self.print_options(opt)
+ self.opt = opt
+ return self.opt
diff --git a/Deep3DFaceRecon_pytorch/options/test_options.py b/Deep3DFaceRecon_pytorch/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4aca5ef369ddd427dd87f31af03f31256d46176
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/options/test_options.py
@@ -0,0 +1,22 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc")
+ parser.add_argument(
+ "--dataset_mode", type=str, default=None, help="chooses how datasets are loaded. [None | flist]"
+ )
+ parser.add_argument("--img_folder", type=str, default="examples", help="folder for test images.")
+
+ # Dropout and Batchnorm has different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/Deep3DFaceRecon_pytorch/options/train_options.py b/Deep3DFaceRecon_pytorch/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e02ee3e87b49cae8f7e660d4b891ef062f06d97
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/options/train_options.py
@@ -0,0 +1,90 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+from util import util
+
+from .base_options import BaseOptions
+
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # dataset parameters
+ # for train
+ parser.add_argument("--data_root", type=str, default="./", help="dataset root")
+ parser.add_argument(
+ "--flist", type=str, default="datalist/train/masks.txt", help="list of mask names of training set"
+ )
+ parser.add_argument("--batch_size", type=int, default=32)
+ parser.add_argument(
+ "--dataset_mode", type=str, default="flist", help="chooses how datasets are loaded. [None | flist]"
+ )
+ parser.add_argument(
+ "--serial_batches",
+ action="store_true",
+ help="if true, takes images in order to make batches, otherwise takes them randomly",
+ )
+ parser.add_argument("--num_threads", default=4, type=int, help="# threads for loading data")
+ parser.add_argument(
+ "--max_dataset_size",
+ type=int,
+ default=float("inf"),
+ help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
+ )
+ parser.add_argument(
+ "--preprocess",
+ type=str,
+ default="shift_scale_rot_flip",
+ help="scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]",
+ )
+ parser.add_argument(
+ "--use_aug", type=util.str2bool, nargs="?", const=True, default=True, help="whether use data augmentation"
+ )
+
+ # for val
+ parser.add_argument(
+ "--flist_val", type=str, default="datalist/val/masks.txt", help="list of mask names of val set"
+ )
+ parser.add_argument("--batch_size_val", type=int, default=32)
+
+ # visualization parameters
+ parser.add_argument(
+ "--display_freq", type=int, default=1000, help="frequency of showing training results on screen"
+ )
+ parser.add_argument(
+ "--print_freq", type=int, default=100, help="frequency of showing training results on console"
+ )
+
+ # network saving and loading parameters
+ parser.add_argument("--save_latest_freq", type=int, default=5000, help="frequency of saving the latest results")
+ parser.add_argument(
+ "--save_epoch_freq", type=int, default=1, help="frequency of saving checkpoints at the end of epochs"
+ )
+ parser.add_argument("--evaluation_freq", type=int, default=5000, help="evaluation freq")
+ parser.add_argument("--save_by_iter", action="store_true", help="whether saves model by iteration")
+ parser.add_argument("--continue_train", action="store_true", help="continue training: load the latest model")
+ parser.add_argument(
+ "--epoch_count",
+ type=int,
+ default=1,
+ help="the starting epoch count, we save the model by , +, ...",
+ )
+ parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")
+ parser.add_argument("--pretrained_name", type=str, default=None, help="resume training from another checkpoint")
+
+ # training parameters
+ parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs with the initial learning rate")
+ parser.add_argument("--lr", type=float, default=0.0001, help="initial learning rate for adam")
+ parser.add_argument(
+ "--lr_policy", type=str, default="step", help="learning rate policy. [linear | step | plateau | cosine]"
+ )
+ parser.add_argument(
+ "--lr_decay_epochs", type=int, default=10, help="multiply by a gamma every lr_decay_epochs epoches"
+ )
+
+ self.isTrain = True
+ return parser
diff --git a/Deep3DFaceRecon_pytorch/test.py b/Deep3DFaceRecon_pytorch/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ddae5b61c68fa6dd607480a0fcc7f3f88199d5d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/test.py
@@ -0,0 +1,94 @@
+"""This script is the test script for Deep3DFaceRecon_pytorch
+"""
+import os
+
+import numpy as np
+import torch
+from data import create_dataset
+from data.flist_dataset import default_flist_reader
+from options.test_options import TestOptions
+from PIL import Image
+from scipy.io import loadmat
+from scipy.io import savemat
+from util.load_mats import load_lm3d
+from util.preprocess import align_img
+from util.visualizer import MyVisualizer
+
+from models import create_model
+
+
+def get_data_path(root="examples"):
+
+ im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith("png") or i.endswith("jpg")]
+ lm_path = [i.replace("png", "txt").replace("jpg", "txt") for i in im_path]
+ lm_path = [
+ os.path.join(i.replace(i.split(os.path.sep)[-1], ""), "detections", i.split(os.path.sep)[-1]) for i in lm_path
+ ]
+
+ return im_path, lm_path
+
+
+def read_data(im_path, lm_path, lm3d_std, to_tensor=True):
+ # to RGB
+ im = Image.open(im_path).convert("RGB")
+ W, H = im.size
+ lm = np.loadtxt(lm_path).astype(np.float32)
+ lm = lm.reshape([-1, 2])
+ lm[:, -1] = H - 1 - lm[:, -1]
+ _, im, lm, _ = align_img(im, lm, lm3d_std)
+ if to_tensor:
+ im = torch.tensor(np.array(im) / 255.0, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
+ lm = torch.tensor(lm).unsqueeze(0)
+ return im, lm
+
+
+def main(rank, opt, name="examples"):
+ device = torch.device(rank)
+ torch.cuda.set_device(device)
+ model = create_model(opt)
+ model.setup(opt)
+ model.device = device
+ model.parallelize()
+ model.eval()
+ visualizer = MyVisualizer(opt)
+
+ im_path, lm_path = get_data_path(name)
+ lm3d_std = load_lm3d(opt.bfm_folder)
+
+ for i in range(len(im_path)):
+ print(i, im_path[i])
+ img_name = im_path[i].split(os.path.sep)[-1].replace(".png", "").replace(".jpg", "")
+ if not os.path.isfile(lm_path[i]):
+ print("%s is not found !!!" % lm_path[i])
+ continue
+ im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std)
+ data = {"imgs": im_tensor, "lms": lm_tensor}
+ model.set_input(data) # unpack data from data loader
+ model.test() # run inference
+ visuals = model.get_current_visuals() # get image results
+ visualizer.display_current_results(
+ visuals,
+ 0,
+ opt.epoch,
+ dataset=name.split(os.path.sep)[-1],
+ save_results=True,
+ count=i,
+ name=img_name,
+ add_image=False,
+ )
+
+ model.save_mesh(
+ os.path.join(
+ visualizer.img_dir, name.split(os.path.sep)[-1], "epoch_%s_%06d" % (opt.epoch, 0), img_name + ".obj"
+ )
+ ) # save reconstruction meshes
+ model.save_coeff(
+ os.path.join(
+ visualizer.img_dir, name.split(os.path.sep)[-1], "epoch_%s_%06d" % (opt.epoch, 0), img_name + ".mat"
+ )
+ ) # save predicted coefficients
+
+
+if __name__ == "__main__":
+ opt = TestOptions().parse() # get test options
+ main(0, opt, opt.img_folder)
diff --git a/Deep3DFaceRecon_pytorch/train.py b/Deep3DFaceRecon_pytorch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fe293ec7561eefc6780a137064f79379c8eb07f
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/train.py
@@ -0,0 +1,194 @@
+"""This script is the training script for Deep3DFaceRecon_pytorch
+"""
+import os
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from data import create_dataset
+from options.train_options import TrainOptions
+from util.util import genvalconf
+from util.visualizer import MyVisualizer
+
+from models import create_model
+
+
+def setup(rank, world_size, port):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = port
+
+ # initialize the process group
+ dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+ dist.destroy_process_group()
+
+
+def main(rank, world_size, train_opt):
+ val_opt = genvalconf(train_opt, isTrain=False)
+
+ device = torch.device(rank)
+ torch.cuda.set_device(device)
+ use_ddp = train_opt.use_ddp
+
+ if use_ddp:
+ setup(rank, world_size, train_opt.ddp_port)
+
+ train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank)
+ train_dataset_batches, val_dataset_batches = (
+ len(train_dataset) // train_opt.batch_size,
+ len(val_dataset) // val_opt.batch_size,
+ )
+
+ model = create_model(train_opt) # create a model given train_opt.model and other options
+ model.setup(train_opt)
+ model.device = device
+ model.parallelize()
+
+ if rank == 0:
+ print(
+ "The batch number of training images = %d\n, \
+ the batch number of validation images = %d"
+ % (train_dataset_batches, val_dataset_batches)
+ )
+ model.print_networks(train_opt.verbose)
+ visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots
+
+ total_iters = train_dataset_batches * (train_opt.epoch_count - 1) # the total number of training iterations
+ t_data = 0
+ t_val = 0
+ optimize_time = 0.1
+ batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size
+
+ if use_ddp:
+ dist.barrier()
+
+ times = []
+ for epoch in range(
+ train_opt.epoch_count, train_opt.n_epochs + 1
+ ): # outer loop for different epochs; we save the model by , +
+ epoch_start_time = time.time() # timer for entire epoch
+ iter_data_time = time.time() # timer for train_data loading per iteration
+ epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
+
+ train_dataset.set_epoch(epoch)
+ for i, train_data in enumerate(train_dataset): # inner loop within one epoch
+ iter_start_time = time.time() # timer for computation per iteration
+ if total_iters % train_opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
+ total_iters += batch_size
+ epoch_iter += batch_size
+
+ torch.cuda.synchronize()
+ optimize_start_time = time.time()
+
+ model.set_input(train_data) # unpack train_data from dataset and apply preprocessing
+ model.optimize_parameters() # calculate loss functions, get gradients, update network weights
+
+ torch.cuda.synchronize()
+ optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
+
+ if use_ddp:
+ dist.barrier()
+
+ if rank == 0 and (
+ total_iters == batch_size or total_iters % train_opt.display_freq == 0
+ ): # display images on visdom and save images to a HTML file
+ model.compute_visuals()
+ visualizer.display_current_results(
+ model.get_current_visuals(), total_iters, epoch, save_results=True, add_image=train_opt.add_image
+ )
+ # (total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0)
+
+ if rank == 0 and (
+ total_iters == batch_size or total_iters % train_opt.print_freq == 0
+ ): # print training losses and save logging information to the disk
+ losses = model.get_current_losses()
+ visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
+ visualizer.plot_current_losses(total_iters, losses)
+
+ if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0:
+ with torch.no_grad():
+ torch.cuda.synchronize()
+ val_start_time = time.time()
+ losses_avg = {}
+ model.eval()
+ for j, val_data in enumerate(val_dataset):
+ model.set_input(val_data)
+ model.optimize_parameters(isTrain=False)
+ if rank == 0 and j < train_opt.vis_batch_nums:
+ model.compute_visuals()
+ visualizer.display_current_results(
+ model.get_current_visuals(),
+ total_iters,
+ epoch,
+ dataset="val",
+ save_results=True,
+ count=j * val_opt.batch_size,
+ add_image=train_opt.add_image,
+ )
+
+ if j < train_opt.eval_batch_nums:
+ losses = model.get_current_losses()
+ for key, value in losses.items():
+ losses_avg[key] = losses_avg.get(key, 0) + value
+
+ for key, value in losses_avg.items():
+ losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches)
+
+ torch.cuda.synchronize()
+ eval_time = time.time() - val_start_time
+
+ if rank == 0:
+ visualizer.print_current_losses(
+ epoch, epoch_iter, losses_avg, eval_time, t_data, dataset="val"
+ ) # visualize training results
+ visualizer.plot_current_losses(total_iters, losses_avg, dataset="val")
+ model.train()
+
+ if use_ddp:
+ dist.barrier()
+
+ if rank == 0 and (
+ total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0
+ ): # cache our latest model every iterations
+ print("saving the latest model (epoch %d, total_iters %d)" % (epoch, total_iters))
+ print(train_opt.name) # it's useful to occasionally show the experiment name on console
+ save_suffix = "iter_%d" % total_iters if train_opt.save_by_iter else "latest"
+ model.save_networks(save_suffix)
+
+ if use_ddp:
+ dist.barrier()
+
+ iter_data_time = time.time()
+
+ print(
+ "End of epoch %d / %d \t Time Taken: %d sec" % (epoch, train_opt.n_epochs, time.time() - epoch_start_time)
+ )
+ model.update_learning_rate() # update learning rates at the end of every epoch.
+
+ if rank == 0 and epoch % train_opt.save_epoch_freq == 0: # cache our model every epochs
+ print("saving the model at the end of epoch %d, iters %d" % (epoch, total_iters))
+ model.save_networks("latest")
+ model.save_networks(epoch)
+
+ if use_ddp:
+ dist.barrier()
+
+
+if __name__ == "__main__":
+
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+ train_opt = TrainOptions().parse() # get training options
+ world_size = train_opt.world_size
+
+ if train_opt.use_ddp:
+ mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True)
+ else:
+ main(0, world_size, train_opt)
diff --git a/Deep3DFaceRecon_pytorch/util/BBRegressorParam_r.mat b/Deep3DFaceRecon_pytorch/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/util/BBRegressorParam_r.mat differ
diff --git a/Deep3DFaceRecon_pytorch/util/__init__.py b/Deep3DFaceRecon_pytorch/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b23b46599a06463b78a40896b038d010e5f847
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/__init__.py
@@ -0,0 +1,2 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from Deep3DFaceRecon_pytorch.util import *
diff --git a/Deep3DFaceRecon_pytorch/util/__pycache__/__init__.cpython-310.pyc b/Deep3DFaceRecon_pytorch/util/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4c3ee2c4aa63a86f6ec50191fa759cbf062848e
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/util/__pycache__/__init__.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/util/__pycache__/load_mats.cpython-310.pyc b/Deep3DFaceRecon_pytorch/util/__pycache__/load_mats.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..874f14f69e5e7dd97c8820830dab5cc411e301e1
Binary files /dev/null and b/Deep3DFaceRecon_pytorch/util/__pycache__/load_mats.cpython-310.pyc differ
diff --git a/Deep3DFaceRecon_pytorch/util/detect_lm68.py b/Deep3DFaceRecon_pytorch/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d6b5c50d485617387b7259da7fd5be8a4396937
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/detect_lm68.py
@@ -0,0 +1,108 @@
+import os
+from shutil import move
+
+import cv2
+import numpy as np
+import tensorflow as tf
+from scipy.io import loadmat
+from util.preprocess import align_for_lm
+
+mean_face = np.loadtxt("util/test_mean_face.txt")
+mean_face = mean_face.reshape([68, 2])
+
+
+def save_label(labels, save_path):
+ np.savetxt(save_path, labels)
+
+
+def draw_landmarks(img, landmark, save_name):
+ landmark = landmark
+ lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+ lm_img[:] = img.astype(np.float32)
+ landmark = np.round(landmark).astype(np.int32)
+
+ for i in range(len(landmark)):
+ for j in range(-1, 1):
+ for k in range(-1, 1):
+ if (
+ img.shape[0] - 1 - landmark[i, 1] + j > 0
+ and img.shape[0] - 1 - landmark[i, 1] + j < img.shape[0]
+ and landmark[i, 0] + k > 0
+ and landmark[i, 0] + k < img.shape[1]
+ ):
+ lm_img[img.shape[0] - 1 - landmark[i, 1] + j, landmark[i, 0] + k, :] = np.array([0, 0, 255])
+ lm_img = lm_img.astype(np.uint8)
+
+ cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+ return cv2.imread(img_name), np.loadtxt(txt_name)
+
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+ with tf.gfile.GFile(graph_filename, "rb") as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ with tf.Graph().as_default() as graph:
+ tf.import_graph_def(graph_def, name="net")
+ img_224 = graph.get_tensor_by_name("net/input_imgs:0")
+ output_lm = graph.get_tensor_by_name("net/lm:0")
+ lm_sess = tf.Session(graph=graph)
+
+ return lm_sess, img_224, output_lm
+
+
+# landmark detection
+def detect_68p(img_path, sess, input_op, output_op):
+ print("detecting landmarks......")
+ names = [i for i in sorted(os.listdir(img_path)) if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i]
+ vis_path = os.path.join(img_path, "vis")
+ remove_path = os.path.join(img_path, "remove")
+ save_path = os.path.join(img_path, "landmarks")
+ if not os.path.isdir(vis_path):
+ os.makedirs(vis_path)
+ if not os.path.isdir(remove_path):
+ os.makedirs(remove_path)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print("%05d" % (i), " ", name)
+ full_image_name = os.path.join(img_path, name)
+ txt_name = ".".join(name.split(".")[:-1]) + ".txt"
+ full_txt_name = os.path.join(img_path, "detections", txt_name) # 5 facial landmark path for each image
+
+ # if an image does not have detected 5 facial landmarks, remove it from the training list
+ if not os.path.isfile(full_txt_name):
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # load data
+ img, five_points = load_data(full_image_name, full_txt_name)
+ input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
+
+ # if the alignment fails, remove corresponding image from the training list
+ if scale == 0:
+ move(full_txt_name, os.path.join(remove_path, txt_name))
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # detect landmarks
+ input_img = np.reshape(input_img, [1, 224, 224, 3]).astype(np.float32)
+ landmark = sess.run(output_op, feed_dict={input_op: input_img})
+
+ # transform back to original image coordinate
+ landmark = landmark.reshape([68, 2]) + mean_face
+ landmark[:, 1] = 223 - landmark[:, 1]
+ landmark = landmark / scale
+ landmark[:, 0] = landmark[:, 0] + bbox[0]
+ landmark[:, 1] = landmark[:, 1] + bbox[1]
+ landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+ if i % 100 == 0:
+ draw_landmarks(img, landmark, os.path.join(vis_path, name))
+ save_label(landmark, os.path.join(save_path, txt_name))
diff --git a/Deep3DFaceRecon_pytorch/util/generate_list.py b/Deep3DFaceRecon_pytorch/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb6a19ccb98aa3542c0e180d675b6561a2a3844a
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode="train", save_folder="datalist", save_name=""):
+ save_path = os.path.join(save_folder, mode)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+ with open(os.path.join(save_path, save_name + "landmarks.txt"), "w") as fd:
+ fd.writelines([i + "\n" for i in lms_list])
+
+ with open(os.path.join(save_path, save_name + "images.txt"), "w") as fd:
+ fd.writelines([i + "\n" for i in imgs_list])
+
+ with open(os.path.join(save_path, save_name + "masks.txt"), "w") as fd:
+ fd.writelines([i + "\n" for i in msks_list])
+
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+ lms_list, imgs_list, msks_list = [], [], []
+ for i in range(len(rlms_list)):
+ flag = "false"
+ lm_path = rlms_list[i]
+ im_path = rimgs_list[i]
+ msk_path = rmsks_list[i]
+ if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+ flag = "true"
+ lms_list.append(rlms_list[i])
+ imgs_list.append(rimgs_list[i])
+ msks_list.append(rmsks_list[i])
+ print(i, rlms_list[i], flag)
+ return lms_list, imgs_list, msks_list
diff --git a/Deep3DFaceRecon_pytorch/util/html.py b/Deep3DFaceRecon_pytorch/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..5570e71e887cd23d09913091da7bfe6c3894b833
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/html.py
@@ -0,0 +1,95 @@
+import os
+
+import dominate
+from dominate.tags import a
+from dominate.tags import br
+from dominate.tags import h3
+from dominate.tags import img
+from dominate.tags import meta
+from dominate.tags import p
+from dominate.tags import table
+from dominate.tags import td
+from dominate.tags import tr
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+ It consists of functions such as (add a text header to the HTML file),
+ (add a row of images to the HTML file), and (save the HTML to the disk).
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+ def __init__(self, web_dir, title, refresh=0):
+ """Initialize the HTML classes
+
+ Parameters:
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join("images", link)):
+ img(style="width:%dpx" % width, src=os.path.join("images", im))
+ br()
+ p(txt)
+
+ def save(self):
+ """save the current content to the HMTL file"""
+ html_file = "%s/index.html" % self.web_dir
+ f = open(html_file, "wt")
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == "__main__": # we show an example usage here.
+ html = HTML("web/", "test_html")
+ html.add_header("hello world")
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append("image_%d.png" % n)
+ txts.append("text_%d" % n)
+ links.append("image_%d.png" % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/Deep3DFaceRecon_pytorch/util/load_mats.py b/Deep3DFaceRecon_pytorch/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c7080907d38119f46020af461c07f91bfbbdbf
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/load_mats.py
@@ -0,0 +1,138 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+import os.path as osp
+from array import array
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat
+from scipy.io import savemat
+
+# load expression basis
+def LoadExpBasis(bfm_folder="BFM"):
+ n_vertex = 53215
+ Expbin = open(osp.join(bfm_folder, "Exp_Pca.bin"), "rb")
+ exp_dim = array("i")
+ exp_dim.fromfile(Expbin, 1)
+ expMU = array("f")
+ expPC = array("f")
+ expMU.fromfile(Expbin, 3 * n_vertex)
+ expPC.fromfile(Expbin, 3 * exp_dim[0] * n_vertex)
+ Expbin.close()
+
+ expPC = np.array(expPC)
+ expPC = np.reshape(expPC, [exp_dim[0], -1])
+ expPC = np.transpose(expPC)
+
+ expEV = np.loadtxt(osp.join(bfm_folder, "std_exp.txt"))
+
+ return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder="BFM"):
+ print("Transfer BFM09 to BFM_model_front......")
+ original_BFM = loadmat(osp.join(bfm_folder, "01_MorphableModel.mat"))
+ shapePC = original_BFM["shapePC"] # shape basis
+ shapeEV = original_BFM["shapeEV"] # corresponding eigen value
+ shapeMU = original_BFM["shapeMU"] # mean face
+ texPC = original_BFM["texPC"] # texture basis
+ texEV = original_BFM["texEV"] # eigen value
+ texMU = original_BFM["texMU"] # mean texture
+
+ expPC, expEV = LoadExpBasis()
+
+ # transfer BFM09 to our face model
+
+ idBase = shapePC * np.reshape(shapeEV, [-1, 199])
+ idBase = idBase / 1e5 # unify the scale to decimeter
+ idBase = idBase[:, :80] # use only first 80 basis
+
+ exBase = expPC * np.reshape(expEV, [-1, 79])
+ exBase = exBase / 1e5 # unify the scale to decimeter
+ exBase = exBase[:, :64] # use only first 64 basis
+
+ texBase = texPC * np.reshape(texEV, [-1, 199])
+ texBase = texBase[:, :80] # use only first 80 basis
+
+ # our face model is cropped along face landmarks and contains only 35709 vertex.
+ # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+ # thus we select corresponding vertex to get our face model.
+
+ index_exp = loadmat(osp.join(bfm_folder, "BFM_front_idx.mat"))
+ index_exp = index_exp["idx"].astype(np.int32) - 1 # starts from 0 (to 53215)
+
+ index_shape = loadmat(osp.join(bfm_folder, "BFM_exp_idx.mat"))
+ index_shape = index_shape["trimIndex"].astype(np.int32) - 1 # starts from 0 (to 53490)
+ index_shape = index_shape[index_exp]
+
+ idBase = np.reshape(idBase, [-1, 3, 80])
+ idBase = idBase[index_shape, :, :]
+ idBase = np.reshape(idBase, [-1, 80])
+
+ texBase = np.reshape(texBase, [-1, 3, 80])
+ texBase = texBase[index_shape, :, :]
+ texBase = np.reshape(texBase, [-1, 80])
+
+ exBase = np.reshape(exBase, [-1, 3, 64])
+ exBase = exBase[index_exp, :, :]
+ exBase = np.reshape(exBase, [-1, 64])
+
+ meanshape = np.reshape(shapeMU, [-1, 3]) / 1e5
+ meanshape = meanshape[index_shape, :]
+ meanshape = np.reshape(meanshape, [1, -1])
+
+ meantex = np.reshape(texMU, [-1, 3])
+ meantex = meantex[index_shape, :]
+ meantex = np.reshape(meantex, [1, -1])
+
+ # other info contains triangles, region used for computing photometric loss,
+ # region used for skin texture regularization, and 68 landmarks index etc.
+ other_info = loadmat(osp.join(bfm_folder, "facemodel_info.mat"))
+ frontmask2_idx = other_info["frontmask2_idx"]
+ skinmask = other_info["skinmask"]
+ keypoints = other_info["keypoints"]
+ point_buf = other_info["point_buf"]
+ tri = other_info["tri"]
+ tri_mask2 = other_info["tri_mask2"]
+
+ # save our face model
+ savemat(
+ osp.join(bfm_folder, "BFM_model_front.mat"),
+ {
+ "meanshape": meanshape,
+ "meantex": meantex,
+ "idBase": idBase,
+ "exBase": exBase,
+ "texBase": texBase,
+ "tri": tri,
+ "point_buf": point_buf,
+ "tri_mask2": tri_mask2,
+ "keypoints": keypoints,
+ "frontmask2_idx": frontmask2_idx,
+ "skinmask": skinmask,
+ },
+ )
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+ Lm3D = loadmat(osp.join(bfm_folder, "similarity_Lm3D_all.mat"))
+ Lm3D = Lm3D["lm"]
+
+ # calculate 5 facial landmarks using 68 landmarks
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ Lm3D = np.stack(
+ [
+ Lm3D[lm_idx[0], :],
+ np.mean(Lm3D[lm_idx[[1, 2]], :], 0),
+ np.mean(Lm3D[lm_idx[[3, 4]], :], 0),
+ Lm3D[lm_idx[5], :],
+ Lm3D[lm_idx[6], :],
+ ],
+ axis=0,
+ )
+ Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+ return Lm3D
diff --git a/Deep3DFaceRecon_pytorch/util/nvdiffrast.py b/Deep3DFaceRecon_pytorch/util/nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db5799ef4e979b8a91281f527ae040c5c35e299
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/nvdiffrast.py
@@ -0,0 +1,91 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+ Attention, antialiasing step is missing in current version.
+"""
+from typing import List
+
+import kornia
+import numpy as np
+import torch
+import torch.nn.functional as F
+from kornia.geometry.camera import pixel2cam
+from scipy.io import loadmat
+from torch import nn
+
+import nvdiffrast.torch as dr
+
+
+def ndc_projection(x=0.1, n=1.0, f=50.0):
+ return np.array(
+ [[n / x, 0, 0, 0], [0, n / -x, 0, 0], [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)], [0, 0, -1, 0]]
+ ).astype(np.float32)
+
+
+class MeshRenderer(nn.Module):
+ def __init__(self, rasterize_fov, znear=0.1, zfar=10, rasterize_size=224, use_opengl=True):
+ super(MeshRenderer, self).__init__()
+
+ x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ torch.diag(torch.tensor([1.0, -1, -1, 1]))
+ )
+ self.rasterize_size = rasterize_size
+ self.use_opengl = use_opengl
+ self.ctx = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ ndc_proj = self.ndc_proj.to(device)
+ # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 1] = -vertex[..., 1]
+
+ vertex_ndc = vertex @ ndc_proj.t()
+ if self.ctx is None:
+ if self.use_opengl:
+ self.ctx = dr.RasterizeGLContext(device=device)
+ ctx_str = "opengl"
+ else:
+ self.ctx = dr.RasterizeCudaContext(device=device)
+ ctx_str = "cuda"
+ print("create %s ctx on device cuda:%d" % (ctx_str, device.index))
+
+ ranges = None
+ if isinstance(tri, List) or len(tri.shape) == 3:
+ vum = vertex_ndc.shape[1]
+ fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ for i in range(tri.shape[0]):
+ tri[i] = tri[i] + i * vum
+ vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ tri = torch.cat(tri, dim=0)
+
+ # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+ rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)
+
+ depth, _ = dr.interpolate(vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(), rast_out, tri)
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+ image = None
+ if feat is not None:
+ image, _ = dr.interpolate(feat, rast_out, tri)
+ image = image.permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
diff --git a/Deep3DFaceRecon_pytorch/util/preprocess.py b/Deep3DFaceRecon_pytorch/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e525c0f4f928a686ad3a751a697c806b3f3ba933
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/preprocess.py
@@ -0,0 +1,244 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from scipy.io import loadmat
+from skimage import transform as trans
+
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# calculating least square problem for image alignment
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2 * npts, 8])
+
+ A[0 : 2 * npts - 1 : 2, 0:3] = x.transpose()
+ A[0 : 2 * npts - 1 : 2, 3] = 1
+
+ A[1 : 2 * npts : 2, 4:7] = x.transpose()
+ A[1 : 2 * npts : 2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2 * npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
+
+# bounding box for 68 landmark detection
+def BBRegression(points, params):
+
+ w1 = params["W1"]
+ b1 = params["B1"]
+ w2 = params["W2"]
+ b2 = params["B2"]
+ data = points.copy()
+ data = data.reshape([5, 2])
+ data_mean = np.mean(data, axis=0)
+ x_mean = data_mean[0]
+ y_mean = data_mean[1]
+ data[:, 0] = data[:, 0] - x_mean
+ data[:, 1] = data[:, 1] - y_mean
+
+ rms = np.sqrt(np.sum(data**2) / 5)
+ data = data / rms
+ data = data.reshape([1, 10])
+ data = np.transpose(data)
+ inputs = np.matmul(w1, data) + b1
+ inputs = 2 / (1 + np.exp(-2 * inputs)) - 1
+ inputs = np.matmul(w2, inputs) + b2
+ inputs = np.transpose(inputs)
+ x = inputs[:, 0] * rms + x_mean
+ y = inputs[:, 1] * rms + y_mean
+ w = 224 / inputs[:, 2] * rms
+ rects = [x, y, w, w]
+ return np.array(rects).reshape([4])
+
+
+# utils for landmark detection
+def img_padding(img, box):
+ success = True
+ bbox = box.copy()
+ res = np.zeros([2 * img.shape[0], 2 * img.shape[1], 3])
+ res[
+ img.shape[0] // 2 : img.shape[0] + img.shape[0] // 2, img.shape[1] // 2 : img.shape[1] + img.shape[1] // 2
+ ] = img
+
+ bbox[0] = bbox[0] + img.shape[1] // 2
+ bbox[1] = bbox[1] + img.shape[0] // 2
+ if bbox[0] < 0 or bbox[1] < 0:
+ success = False
+ return res, bbox, success
+
+
+# utils for landmark detection
+def crop(img, bbox):
+ padded_img, padded_bbox, flag = img_padding(img, bbox)
+ if flag:
+ crop_img = padded_img[
+ padded_bbox[1] : padded_bbox[1] + padded_bbox[3], padded_bbox[0] : padded_bbox[0] + padded_bbox[2]
+ ]
+ crop_img = cv2.resize(crop_img.astype(np.uint8), (224, 224), interpolation=cv2.INTER_CUBIC)
+ scale = 224 / padded_bbox[3]
+ return crop_img, scale
+ else:
+ return padded_img, 0
+
+
+# utils for landmark detection
+def scale_trans(img, lm, t, s):
+ imgw = img.shape[1]
+ imgh = img.shape[0]
+ M_s = np.array([[1, 0, -t[0] + imgw // 2 + 0.5], [0, 1, -imgh // 2 + t[1]]], dtype=np.float32)
+ img = cv2.warpAffine(img, M_s, (imgw, imgh))
+ w = int(imgw / s * 100)
+ h = int(imgh / s * 100)
+ img = cv2.resize(img, (w, h))
+ lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] - t[1] + imgh // 2], axis=1) / s * 100
+
+ left = w // 2 - 112
+ up = h // 2 - 112
+ bbox = [left, up, 224, 224]
+ cropped_img, scale2 = crop(img, bbox)
+ assert scale2 != 0
+ t1 = np.array([bbox[0], bbox[1]])
+
+ # back to raw img s * crop + s * t1 + t2
+ t1 = np.array([w // 2 - 112, h // 2 - 112])
+ scale = s / 100
+ t2 = np.array([t[0] - imgw / 2, t[1] - imgh / 2])
+ inv = (scale / scale2, scale * t1 + t2.reshape([2]))
+ return cropped_img, inv
+
+
+# utils for landmark detection
+def align_for_lm(img, five_points):
+ five_points = np.array(five_points).reshape([1, 10])
+ params = loadmat("util/BBRegressorParam_r.mat")
+ bbox = BBRegression(five_points, params)
+ assert bbox[2] != 0
+ bbox = np.round(bbox).astype(np.int32)
+ crop_img, scale = crop(img, bbox)
+ return crop_img, scale, bbox
+
+
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224.0, mask=None):
+ w0, h0 = img.size
+ w = (w0 * s).astype(np.int32)
+ h = (h0 * s).astype(np.int32)
+ left = (w / 2 - target_size / 2 + float((t[0] - w0 / 2) * s)).astype(np.int32)
+ right = left + target_size
+ up = (h / 2 - target_size / 2 + float((h0 / 2 - t[1]) * s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.Resampling.BICUBIC)
+ img = img.crop((left, up, right, below))
+
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.Resampling.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2], axis=1) * s
+ lm = lm - np.reshape(np.array([(w / 2 - target_size / 2), (h / 2 - target_size / 2)]), [1, 2])
+
+ return img, lm, mask
+
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack(
+ [
+ lm[lm_idx[0], :],
+ np.mean(lm[lm_idx[[1, 2]], :], 0),
+ np.mean(lm[lm_idx[[3, 4]], :], 0),
+ lm[lm_idx[5], :],
+ lm[lm_idx[6], :],
+ ],
+ axis=0,
+ )
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224.0, rescale_factor=102.0):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+
+ w0, h0 = img.size
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor / s
+
+ # processing the image
+ img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+ return trans_params, img_new, lm_new, mask_new
+
+
+# utils for face recognition model
+def estimate_norm(lm_68p, H):
+ # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68
+ """
+ Return:
+ trans_m --numpy.array (2, 3)
+ Parameters:
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ H --int/float , image height
+ """
+ lm = extract_5p(lm_68p)
+ lm[:, -1] = H - 1 - lm[:, -1]
+ tform = trans.SimilarityTransform()
+ src = np.array(
+ [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]],
+ dtype=np.float32,
+ )
+ tform.estimate(lm, src)
+ M = tform.params
+ if np.linalg.det(M) == 0:
+ M = np.eye(3)
+
+ return M[0:2, :]
+
+
+def estimate_norm_torch(lm_68p, H):
+ lm_68p_ = lm_68p.detach().cpu().numpy()
+ M = []
+ for i in range(lm_68p_.shape[0]):
+ M.append(estimate_norm(lm_68p_[i], H))
+ M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device)
+ return M
diff --git a/Deep3DFaceRecon_pytorch/util/skin_mask.py b/Deep3DFaceRecon_pytorch/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8f2eb584fc94c617f244a18268125d3968c029
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/skin_mask.py
@@ -0,0 +1,179 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+import math
+import os
+
+import cv2
+import numpy as np
+
+
+class GMM:
+ def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+ self.dim = dim # feature dimension
+ self.num = num # number of Gaussian components
+ self.w = w # weights of Gaussian components (a list of scalars)
+ self.mu = mu # mean of Gaussian components (a list of 1xdim vectors)
+ self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
+ self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars)
+ self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
+
+ self.factor = [0] * num
+ for i in range(self.num):
+ self.factor[i] = (2 * math.pi) ** (self.dim / 2) * self.cov_det[i] ** 0.5
+
+ def likelihood(self, data):
+ assert data.shape[1] == self.dim
+ N = data.shape[0]
+ lh = np.zeros(N)
+
+ for i in range(self.num):
+ data_ = data - self.mu[i]
+
+ tmp = np.matmul(data_, self.cov_inv[i]) * data_
+ tmp = np.sum(tmp, axis=1)
+ power = -0.5 * tmp
+
+ p = np.array([math.exp(power[j]) for j in range(N)])
+ p = p / self.factor[i]
+ lh += p * self.w[i]
+
+ return lh
+
+
+def _rgb2ycbcr(rgb):
+ m = np.array([[65.481, 128.553, 24.966], [-37.797, -74.203, 112], [112, -93.786, -18.214]])
+ shape = rgb.shape
+ rgb = rgb.reshape((shape[0] * shape[1], 3))
+ ycbcr = np.dot(rgb, m.transpose() / 255.0)
+ ycbcr[:, 0] += 16.0
+ ycbcr[:, 1:] += 128.0
+ return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+ rgb = bgr[..., ::-1]
+ return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [
+ np.array([113.71862, 103.39613, 164.08226]),
+ np.array([150.19858, 105.18467, 155.51428]),
+ np.array([183.92976, 107.62468, 152.71820]),
+ np.array([114.90524, 113.59782, 151.38217]),
+]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131.0, 1585971.0]
+gmm_skin_cov_inv = [
+ np.array(
+ [
+ [0.0019472069, 0.0020450759, -0.00060243998],
+ [0.0020450759, 0.017700525, 0.0051420014],
+ [-0.00060243998, 0.0051420014, 0.0081308950],
+ ]
+ ),
+ np.array(
+ [
+ [0.0027110141, 0.0011036990, 0.0023122299],
+ [0.0011036990, 0.010707724, 0.010742856],
+ [0.0023122299, 0.010742856, 0.017481629],
+ ]
+ ),
+ np.array(
+ [
+ [0.0048026871, 0.00022935172, 0.0077668377],
+ [0.00022935172, 0.011729696, 0.0081661865],
+ [0.0077668377, 0.0081661865, 0.025374353],
+ ]
+ ),
+ np.array(
+ [
+ [0.0011989699, 0.0022453172, -0.0010748957],
+ [0.0022453172, 0.047758564, 0.020332102],
+ [-0.0010748957, 0.020332102, 0.024502251],
+ ]
+ ),
+]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [
+ np.array([99.200851, 112.07533, 140.20602]),
+ np.array([110.91392, 125.52969, 130.19237]),
+ np.array([129.75864, 129.96107, 126.96808]),
+ np.array([112.29587, 128.85121, 129.05431]),
+]
+gmm_nonskin_cov_det = [458703648.0, 6466488.0, 90611376.0, 133097.63]
+gmm_nonskin_cov_inv = [
+ np.array(
+ [
+ [0.00085371657, 0.00071197288, 0.00023958916],
+ [0.00071197288, 0.0025935620, 0.00076557708],
+ [0.00023958916, 0.00076557708, 0.0015042332],
+ ]
+ ),
+ np.array(
+ [
+ [0.00024650150, 0.00045542428, 0.00015019422],
+ [0.00045542428, 0.026412144, 0.018419769],
+ [0.00015019422, 0.018419769, 0.037497383],
+ ]
+ ),
+ np.array(
+ [
+ [0.00037054974, 0.00038146760, 0.00040408765],
+ [0.00038146760, 0.0085505722, 0.0079136286],
+ [0.00040408765, 0.0079136286, 0.010982352],
+ ]
+ ),
+ np.array(
+ [
+ [0.00013709733, 0.00051228428, 0.00012777430],
+ [0.00051228428, 0.28237113, 0.10528370],
+ [0.00012777430, 0.10528370, 0.23468947],
+ ]
+ ),
+]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+ im = _bgr2ycbcr(imbgr)
+
+ data = im.reshape((-1, 3))
+
+ lh_skin = gmm_skin.likelihood(data)
+ lh_nonskin = gmm_nonskin.likelihood(data)
+
+ tmp1 = prior_skin * lh_skin
+ tmp2 = prior_nonskin * lh_nonskin
+ post_skin = tmp1 / (tmp1 + tmp2) # posterior probability
+
+ post_skin = post_skin.reshape((im.shape[0], im.shape[1]))
+
+ post_skin = np.round(post_skin * 255)
+ post_skin = post_skin.astype(np.uint8)
+ post_skin = np.tile(np.expand_dims(post_skin, 2), [1, 1, 3]) # reshape to H*W*3
+
+ return post_skin
+
+
+def get_skin_mask(img_path):
+ print("generating skin masks......")
+ names = [i for i in sorted(os.listdir(img_path)) if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i]
+ save_path = os.path.join(img_path, "mask")
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print("%05d" % (i), " ", name)
+ full_image_name = os.path.join(img_path, name)
+ img = cv2.imread(full_image_name).astype(np.float32)
+ skin_img = skinmask(img)
+ cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
diff --git a/Deep3DFaceRecon_pytorch/util/test_mean_face.txt b/Deep3DFaceRecon_pytorch/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/Deep3DFaceRecon_pytorch/util/util.py b/Deep3DFaceRecon_pytorch/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3cddd9c423a068cc8f0b008bbda902334045b1e
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/util.py
@@ -0,0 +1,220 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+
+import argparse
+import importlib
+import os
+from argparse import Namespace
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+
+def genvalconf(train_opt, **kwargs):
+ conf = Namespace(**vars(train_opt))
+ attr_dict = train_opt.__dict__
+ for key, value in attr_dict.items():
+ if "val" in key and key.split("_")[0] in attr_dict:
+ setattr(conf, key.split("_")[0], value)
+
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+
+ return conf
+
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace("_", "").lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert (
+ cls is not None
+ ), "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (
+ module,
+ target_cls_name,
+ )
+
+ return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+ """ "Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array, range(0, 1)
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name="network"):
+ """Calculate and print the mean of average absolute(gradients)
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.Resampling.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.Resampling.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print("shape,", x.shape)
+ if val:
+ x = x.flatten()
+ print(
+ "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f"
+ % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))
+ )
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.Resampling.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i : i + 1]
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.Resampling.BICUBIC)
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def draw_landmarks(img, landmark, color="r", step=2):
+ """
+ Return:
+ img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+
+
+ Parameters:
+ img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+ color -- str, 'r' or 'b' (red or blue)
+ """
+ if color == "r":
+ c = np.array([255.0, 0, 0])
+ else:
+ c = np.array([0, 0, 255.0])
+
+ _, H, W, _ = img.shape
+ img, landmark = img.copy(), landmark.copy()
+ landmark[..., 1] = H - 1 - landmark[..., 1]
+ landmark = np.round(landmark).astype(np.int32)
+ for i in range(landmark.shape[1]):
+ x, y = landmark[:, i, 0], landmark[:, i, 1]
+ for j in range(-step, step):
+ for k in range(-step, step):
+ u = np.clip(x + j, 0, W - 1)
+ v = np.clip(y + k, 0, H - 1)
+ for m in range(landmark.shape[0]):
+ img[m, v[m], u[m]] = c
+ return img
diff --git a/Deep3DFaceRecon_pytorch/util/visualizer.py b/Deep3DFaceRecon_pytorch/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa6bff6decff9dd5b3e80233ae15ef4e35fe0b70
--- /dev/null
+++ b/Deep3DFaceRecon_pytorch/util/visualizer.py
@@ -0,0 +1,237 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+import ntpath
+import os
+import sys
+import time
+from subprocess import PIPE
+from subprocess import Popen
+
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from . import html
+from . import util
+
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+ webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = "%s/%s.png" % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer:
+ """This class includes several functions that can display/save images and print/save logging information.
+
+ It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.use_html = opt.isTrain and not opt.no_html
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, "logs", opt.name))
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.saved = False
+ if (
+ self.use_html
+ ): # create an HTML object at /web/; images will be saved under /web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, "web")
+ self.img_dir = os.path.join(self.web_dir, "images")
+ print("create web directory %s..." % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write("================ Training Loss (%s) ================\n" % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+ def display_current_results(self, visuals, total_iters, epoch, save_result):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ for label, image in visuals.items():
+ self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats="HWC")
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, "epoch%.3d_%s.png" % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, "Experiment name = %s" % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header("epoch [%d]" % n)
+ ims, txts, links = [], [], []
+
+ for label, image_numpy in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = "epoch%.3d_%s.png" % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, total_iters, losses):
+ # G_loss_collection = {}
+ # D_loss_collection = {}
+ # for name, value in losses.items():
+ # if 'G' in name or 'NCE' in name or 'idt' in name:
+ # G_loss_collection[name] = value
+ # else:
+ # D_loss_collection[name] = value
+ # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+ # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+ for name, value in losses.items():
+ self.writer.add_scalar(name, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = "(epoch: %d, iters: %d, time: %.3f, data: %.3f) " % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += "%s: %.3f " % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write("%s\n" % message) # save the message
+
+
+class MyVisualizer:
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the optio
+ self.name = opt.name
+ self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, "results")
+
+ if opt.phase != "test":
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "logs"))
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write("================ Training Loss (%s) ================\n" % now)
+
+ def display_current_results(
+ self, visuals, total_iters, epoch, dataset="train", save_results=False, count=0, name=None, add_image=True
+ ):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ dataset (str) - - 'train' or 'val' or 'test'
+ """
+ # if (not add_image) and (not save_results): return
+
+ for label, image in visuals.items():
+ for i in range(image.shape[0]):
+ image_numpy = util.tensor2im(image[i])
+ if add_image:
+ self.writer.add_image(
+ label + "%s_%02d" % (dataset, i + count), image_numpy, total_iters, dataformats="HWC"
+ )
+
+ if save_results:
+ save_path = os.path.join(self.img_dir, dataset, "epoch_%s_%06d" % (epoch, total_iters))
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ if name is not None:
+ img_path = os.path.join(save_path, "%s.png" % name)
+ else:
+ img_path = os.path.join(save_path, "%s_%03d.png" % (label, i + count))
+ util.save_image(image_numpy, img_path)
+
+ def plot_current_losses(self, total_iters, losses, dataset="train"):
+ for name, value in losses.items():
+ self.writer.add_scalar(name + "/%s" % dataset, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset="train"):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = "(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) " % (
+ dataset,
+ epoch,
+ iters,
+ t_comp,
+ t_data,
+ )
+ for k, v in losses.items():
+ message += "%s: %.3f " % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write("%s\n" % message) # save the message
diff --git a/HRNet/__pycache__/hrnet.cpython-310.pyc b/HRNet/__pycache__/hrnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b6585f9f18439a78ae9ddb0cbc084a237e8221a
Binary files /dev/null and b/HRNet/__pycache__/hrnet.cpython-310.pyc differ
diff --git a/HRNet/hrnet.py b/HRNet/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de9acc42e4b7b1077fd691542b1969f54a363df
--- /dev/null
+++ b/HRNet/hrnet.py
@@ -0,0 +1,394 @@
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+BatchNorm2d = nn.BatchNorm2d
+BN_MOMENTUM = 0.01
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(
+ self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True
+ ):
+ super(HighResolutionModule, self).__init__()
+ # self._check_branches(
+ # num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=True)
+
+ # def _check_branches(self, num_branches, blocks, num_blocks,
+ # num_inchannels, num_channels):
+ # if num_branches != len(num_blocks):
+ # error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ # num_branches, len(num_blocks))
+ # logger.error(error_msg)
+ # raise ValueError(error_msg)
+
+ # if num_branches != len(num_channels):
+ # error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ # num_branches, len(num_channels))
+ # logger.error(error_msg)
+ # raise ValueError(error_msg)
+
+ # if num_branches != len(num_inchannels):
+ # error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ # num_branches, len(num_inchannels))
+ # logger.error(error_msg)
+ # raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
+ downsample = None
+ if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample))
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM),
+ )
+ )
+ # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
+ BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM),
+ )
+ )
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
+ BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ )
+ )
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ elif j > i:
+ y = y + F.interpolate(
+ self.fuse_layers[i][j](x[j]), size=[x[i].shape[2], x[i].shape[3]], mode="bilinear"
+ )
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
+
+
+class HighResolutionNet(nn.Module):
+ def __init__(
+ self,
+ ):
+ self.inplanes = 64
+ super(HighResolutionNet, self).__init__()
+
+ # stem net
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.sf = nn.Softmax(dim=1)
+ self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
+
+ num_channels = [18, 36]
+ block = blocks_dict["BASIC"]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer([256], num_channels)
+ config_list = [1, 2, [4, 4], [18, 36], "BASIC", "SUM"]
+ self.stage2, pre_stage_channels = self._make_stage(config_list, num_channels)
+
+ num_channels = [18, 36, 72]
+ block = blocks_dict["BASIC"]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+ config_list = [4, 3, [4, 4, 4], [18, 36, 72], "BASIC", "SUM"]
+ self.stage3, pre_stage_channels = self._make_stage(config_list, num_channels)
+
+ num_channels = [18, 36, 72, 144]
+ block = blocks_dict["BASIC"]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
+ config_list = [3, 4, [4, 4, 4, 4], [18, 36, 72, 144], "BASIC", "SUM"]
+ self.stage4, pre_stage_channels = self._make_stage(config_list, num_channels, multi_scale_output=True)
+
+ final_inp_channels = sum(pre_stage_channels)
+
+ self.head = nn.Sequential(
+ nn.Conv2d(
+ in_channels=final_inp_channels, out_channels=final_inp_channels, kernel_size=1, stride=1, padding=0
+ ),
+ BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=final_inp_channels, out_channels=98, kernel_size=1, stride=1, padding=0),
+ )
+
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
+ BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ )
+ )
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i + 1 - num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ )
+ )
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride, downsample))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, config_list, num_inchannels, multi_scale_output=True):
+ num_modules = config_list[0]
+ num_branches = config_list[1]
+ num_blocks = config_list[2]
+ num_channels = config_list[3]
+ block = blocks_dict[config_list[4]]
+ fuse_method = config_list[5]
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+ modules.append(
+ HighResolutionModule(
+ num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output
+ )
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, x):
+ # h, w = x.size(2), x.size(3)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(2):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(3):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(4):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ x = self.stage4(x_list)
+
+ # Head Part
+ height, width = x[0].size(2), x[0].size(3)
+ x1 = F.interpolate(x[1], size=(height, width), mode="bilinear", align_corners=False)
+ x2 = F.interpolate(x[2], size=(height, width), mode="bilinear", align_corners=False)
+ x3 = F.interpolate(x[3], size=(height, width), mode="bilinear", align_corners=False)
+ x = torch.cat([x[0], x1, x2, x3], 1)
+ x = self.head(x)
+
+ return x
diff --git a/README.md b/README.md
index c56830a933d9167c4d8380ec864c3d766b35f371..7d028631c20d12639b166f0a3582378220fa64a9 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,9 @@
---
-title: HiFiFace Inference Demo
-emoji: 🏆
-colorFrom: indigo
-colorTo: pink
-sdk: gradio
-sdk_version: 3.50.2
-app_file: app.py
+title: HiFiFace Inference
+emoji: 📉
+colorFrom: blue
+colorTo: yellow
+sdk: docker
pinned: false
license: mit
---
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f3f1637e6f69a98e639162f360789a3d541a52
--- /dev/null
+++ b/app.py
@@ -0,0 +1,137 @@
+import argparse
+
+import gradio as gr
+
+from benchmark.app_image import ImageSwap
+from benchmark.app_video import VideoSwap
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+class ConfigPath:
+ face_detector_weights = "./checkpoints/face_detector/face_detector_scrfd_10g_bnkps.onnx"
+ model_path = ""
+ model_idx = 80000
+ ffmpeg_device = "cuda"
+ device = "cuda"
+
+
+def main():
+ cfg = ConfigPath()
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_path", default="./checkpoints/standard_model")
+ parser.add_argument("-i", "--model_idx", default="320000")
+ parser.add_argument("-f", "--ffmpeg_device", default="cpu")
+ parser.add_argument("-d", "--device", default="cpu")
+
+ args = parser.parse_args()
+
+ cfg.model_path = args.model_path
+ cfg.model_idx = int(args.model_idx)
+ cfg.ffmpeg_device = args.ffmpeg_device
+ cfg.device = args.device
+ opt = TrainConfig()
+ checkpoint = (cfg.model_path, cfg.model_idx)
+ model = HifiFace(opt.identity_extractor_config, is_training=False, device=cfg.device, load_checkpoint=checkpoint)
+
+ image_infer = ImageSwap(cfg, model)
+ video_infer = VideoSwap(cfg, model)
+
+ def inference_image(source_face, target_face, shape_rate, id_rate, iterations):
+ return image_infer.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
+
+ def inference_video(source_face, target_video, shape_rate, id_rate, iterations):
+ return video_infer.inference(source_face, target_video, shape_rate, id_rate, int(iterations))
+
+ model_name = cfg.model_path.split("/")[-1] + ":" + f"{cfg.model_idx}"
+ with gr.Blocks(title="FaceSwap") as demo:
+ gr.Markdown(
+ f"""
+ ### model: {model_name}
+ """
+ )
+ with gr.Tab("Image swap"):
+ with gr.Row():
+ source_image = gr.Image(shape=None, label="source image")
+ target_image = gr.Image(shape=None, label="target image")
+ with gr.Row():
+ with gr.Column():
+ structure_sim = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="3d similarity",
+ )
+ id_sim = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="id similarity",
+ )
+ iters = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ label="iters",
+ )
+ image_btn = gr.Button("image swap")
+ output_image = gr.Image(shape=None, label="Result")
+
+ image_btn.click(
+ fn=inference_image,
+ inputs=[source_image, target_image, structure_sim, id_sim, iters],
+ outputs=output_image,
+ )
+
+ with gr.Tab("Video swap"):
+ with gr.Row():
+ source_image = gr.Image(shape=None, label="source image")
+ target_video = gr.Video(value=None, label="target video")
+ with gr.Row():
+ with gr.Column():
+ structure_sim = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="3d similarity",
+ )
+ id_sim = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="id similarity",
+ )
+ iters = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ label="iters",
+ )
+ video_btn = gr.Button("video swap")
+ output_video = gr.Video(value=None, label="Result")
+
+ video_btn.click(
+ fn=inference_video,
+ inputs=[
+ source_image,
+ target_video,
+ structure_sim,
+ id_sim,
+ iters,
+ ],
+ outputs=output_video,
+ )
+
+ demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/arcface_torch/README.md b/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d391f63684dd1f47900dc6449a5e22fa25e3da3
--- /dev/null
+++ b/arcface_torch/README.md
@@ -0,0 +1,218 @@
+# Distributed Arcface Training in Pytorch
+
+The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
+
+[](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
+
+## Requirements
+
+To avail the latest features of PyTorch, we have upgraded to version 1.12.0.
+
+- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
+- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
+- `pip install -r requirement.txt`.
+
+## How to Training
+
+To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
+
+### 1. To run on one GPU:
+
+```shell
+python train_v2.py configs/ms1mv3_r50_onegpu
+```
+
+Note:
+It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
+
+
+### 2. To run on a machine with 8 GPUs:
+
+```shell
+torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
+```
+
+### 3. To run on 2 machines with 8 GPUs each:
+
+Node 0:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+Node 1:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+### 4. Run ViT-B on a machine with 24k batchsize:
+
+```shell
+torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
+```
+
+
+## Download Datasets or Prepare Datasets
+- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
+- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
+- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
+- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
+- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
+
+Note:
+If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it.
+Example:
+
+`python scripts/shuffle_rec.py ms1m-retinaface-t1`
+
+You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
+
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found in here.
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
+
+ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
+recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
+As the result, we can evaluate the FAIR performance for different algorithms.
+
+For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
+
+
+#### 1. Training on Single-Host GPU
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
+| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
+| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
+| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
+| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
+| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
+| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
+| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
+| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
+| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
+| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
+| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
+| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
+| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
+
+#### 2. Training on Multi-Host GPU
+
+| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
+| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
+| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
+| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
+| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
+
+`r100(128*32)` means backbone is r100, batchsize per gpu is 128, the number of gpus is 32.
+
+
+
+#### 3. ViT For Face Recognition
+
+| Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
+| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
+| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
+| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
+| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
+| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
+
+`WF42M` means WebFace42M, `PFC-0.3` means negivate class centers sample rate is 0.3.
+
+#### 4. Noisy Datasets
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:-------------------------|:---------|:------------|:------------|:------------|:---------|
+| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
+| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
+| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
+| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
+
+`WF12M` means WebFace12M, `+PFC-0.1*` denotes additional abnormal inter-class filtering.
+
+
+
+## Speed Benchmark
+
+
+
+**Arcface-Torch** is an efficient tool for training large-scale face recognition training sets. When the number of classes in the training sets exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training performance and lower GPU memory utilization. The partial FC is a sparse variant of the model parallel architecture for large-scale face recognition, utilizing a sparse softmax that dynamically samples a subset of class centers for each training batch. During each iteration, only a sparse portion of the parameters are updated, leading to a significant reduction in GPU memory requirements and computational demands. With the partial FC approach, it is possible to train sets with up to 29 million identities, the largest to date. Furthermore, the partial FC method supports multi-machine distributed training and mixed precision training.
+
+
+
+More details see
+[speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Optimal)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 4681 | 4824 | 5004 |
+| 1400000 | **1672** | 3043 | 4738 |
+| 5500000 | **-** | **1389** | 3975 |
+| 8000000 | **-** | **-** | 3565 |
+| 16000000 | **-** | **-** | 2679 |
+| 29000000 | **-** | **-** | **1855** |
+
+> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Optimal)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 7358 | 5306 | 4868 |
+| 1400000 | 32252 | 11178 | 6056 |
+| 5500000 | **-** | 32188 | 9854 |
+| 8000000 | **-** | **-** | 12310 |
+| 16000000 | **-** | **-** | 19950 |
+| 29000000 | **-** | **-** | 32324 |
+
+
+## Citations
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{An_2022_CVPR,
+ author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
+ title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2022},
+ pages={4042-4051}
+}
+@inproceedings{zhu2021webface260m,
+ title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
+ author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={10492--10502},
+ year={2021}
+}
+```
diff --git a/arcface_torch/backbones/__init__.py b/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94288c3af835e3513ddc70eb4cfb7f7e86852e3f
--- /dev/null
+++ b/arcface_torch/backbones/__init__.py
@@ -0,0 +1,157 @@
+from .iresnet import iresnet100
+from .iresnet import iresnet18
+from .iresnet import iresnet200
+from .iresnet import iresnet34
+from .iresnet import iresnet50
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+
+ return iresnet2060(False, **kwargs)
+
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+
+ elif name == "mbf_large":
+ from .mobilefacenet import get_mbf_large
+
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf_large(fp16=fp16, num_features=num_features)
+
+ elif name == "vit_t":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=256,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ )
+
+ elif name == "vit_t_dp005_mask0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=256,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.0,
+ )
+
+ elif name == "vit_s":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ )
+
+ elif name == "vit_s_dp005_mask_0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=12,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.0,
+ )
+
+ elif name == "vit_b":
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.1,
+ norm_layer="ln",
+ mask_ratio=0.1,
+ using_checkpoint=True,
+ )
+
+ elif name == "vit_b_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=512,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.05,
+ using_checkpoint=True,
+ )
+
+ elif name == "vit_l_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+
+ return VisionTransformer(
+ img_size=112,
+ patch_size=9,
+ num_classes=num_features,
+ embed_dim=768,
+ depth=24,
+ num_heads=8,
+ drop_path_rate=0.05,
+ norm_layer="ln",
+ mask_ratio=0.05,
+ using_checkpoint=True,
+ )
+
+ else:
+ raise ValueError()
diff --git a/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc b/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b88a8c7f0d678d6a3991193870e7fce4d217df7
Binary files /dev/null and b/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc differ
diff --git a/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc b/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cee5104b72dea2fe4d36712c760ba92e20e6d8b6
Binary files /dev/null and b/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc differ
diff --git a/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc b/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b444fe822f926df2031431d0162891482c7c490
Binary files /dev/null and b/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc differ
diff --git a/arcface_torch/backbones/iresnet.py b/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3eea3ac6c1c92a9a92dab3518630cb5039bdf8
--- /dev/null
+++ b/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,198 @@
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
+using_ckpt = False
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(
+ inplanes,
+ eps=1e-05,
+ )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward_impl(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+ def forward(self, x):
+ if self.training and using_ckpt:
+ return checkpoint(self.forward_impl, x)
+ else:
+ return self.forward_impl(x)
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(
+ self,
+ block,
+ layers,
+ dropout=0,
+ num_features=512,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ fp16=False,
+ ):
+ super(IResNet, self).__init__()
+ self.extra_gflops = 0.0
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+ )
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(
+ 512 * block.expansion,
+ eps=1e-05,
+ )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(
+ planes * block.expansion,
+ eps=1e-05,
+ ),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
+ )
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet100", IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet200", IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs)
diff --git a/arcface_torch/backbones/iresnet2060.py b/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..468b00201a06e33653f1e0aa738668cf4ee68fb0
--- /dev/null
+++ b/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,182 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ["iresnet2060"]
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(
+ inplanes,
+ eps=1e-05,
+ )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(
+ planes,
+ eps=1e-05,
+ )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(
+ self,
+ block,
+ layers,
+ dropout=0,
+ num_features=512,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ fp16=False,
+ ):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+ )
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(
+ 512 * block.expansion,
+ eps=1e-05,
+ )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(
+ planes * block.expansion,
+ eps=1e-05,
+ ),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
+ )
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet("iresnet2060", IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/arcface_torch/backbones/mobilefacenet.py b/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36953e8172aa7cdbd58decbf1414c061459526d
--- /dev/null
+++ b/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,160 @@
+"""
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+"""
+import torch
+import torch.nn as nn
+from torch.nn import BatchNorm1d
+from torch.nn import BatchNorm2d
+from torch.nn import Conv2d
+from torch.nn import Linear
+from torch.nn import Module
+from torch.nn import PReLU
+from torch.nn import Sequential
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
+ super(MobileFaceNet, self).__init__()
+ self.scale = scale
+ self.fp16 = fp16
+ self.layers = nn.ModuleList()
+ self.layers.append(ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)))
+ if blocks[0] == 1:
+ self.layers.append(
+ ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
+ )
+ else:
+ self.layers.append(
+ Residual(
+ 64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ )
+
+ self.layers.extend(
+ [
+ DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(
+ 64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(
+ 128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(
+ 128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
+ ),
+ ]
+ )
+
+ self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ for func in self.layers:
+ x = func(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
+
+
+def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
diff --git a/arcface_torch/backbones/vit.py b/arcface_torch/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bb28b60309c0e904adb0830ffcf265d9f1dec6
--- /dev/null
+++ b/arcface_torch/backbones/vit.py
@@ -0,0 +1,302 @@
+from typing import Callable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+from timm.models.layers import to_2tuple
+from timm.models.layers import trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class VITBatchNorm(nn.Module):
+ def __init__(self, num_features):
+ super().__init__()
+ self.num_features = num_features
+ self.bn = nn.BatchNorm1d(num_features=num_features)
+
+ def forward(self, x):
+ return self.bn(x)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+
+ with torch.cuda.amp.autocast(True):
+ batch_size, num_token, embed_dim = x.shape
+ # qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
+ qkv = (
+ self.qkv(x)
+ .reshape(batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ with torch.cuda.amp.autocast(False):
+ q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
+ with torch.cuda.amp.autocast(True):
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ num_patches: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ drop_path: float = 0.0,
+ act_layer: Callable = nn.ReLU6,
+ norm_layer: str = "ln",
+ patch_n: int = 144,
+ ):
+ super().__init__()
+
+ if norm_layer == "bn":
+ self.norm1 = VITBatchNorm(num_features=num_patches)
+ self.norm2 = VITBatchNorm(num_features=num_patches)
+ elif norm_layer == "ln":
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000**3)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ with torch.cuda.amp.autocast(True):
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ batch_size, channels, height, width = x.shape
+ assert (
+ height == self.img_size[0] and width == self.img_size[1]
+ ), f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size: int = 112,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ hybrid_backbone: Optional[None] = None,
+ norm_layer: str = "ln",
+ mask_ratio=0.1,
+ using_checkpoint=False,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ # num_features for consistency with other models
+ self.num_features = self.embed_dim = embed_dim
+
+ if hybrid_backbone is not None:
+ raise ValueError
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
+ )
+ self.mask_ratio = mask_ratio
+ self.using_checkpoint = using_checkpoint
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ patch_n = (img_size // patch_size) ** 2
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ num_patches=num_patches,
+ patch_n=patch_n,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.extra_gflops = 0.0
+ for _block in self.blocks:
+ self.extra_gflops += _block.extra_gflops
+
+ if norm_layer == "ln":
+ self.norm = nn.LayerNorm(embed_dim)
+ elif norm_layer == "bn":
+ self.norm = VITBatchNorm(self.num_patches)
+
+ # features head
+ self.feature = nn.Sequential(
+ nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
+ nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
+ nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
+ nn.BatchNorm1d(num_features=num_classes, eps=2e-5),
+ )
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ torch.nn.init.normal_(self.mask_token, std=0.02)
+ trunc_normal_(self.pos_embed, std=0.02)
+ # trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def get_classifier(self):
+ return self.head
+
+ def random_masking(self, x, mask_ratio=0.1):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.size() # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ # ascend: small is keep, large is remove
+ ids_shuffle = torch.argsort(noise, dim=1)
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ if self.training and self.mask_ratio > 0:
+ x, _, ids_restore = self.random_masking(x)
+
+ for func in self.blocks:
+ if self.using_checkpoint and self.training:
+ from torch.utils.checkpoint import checkpoint
+
+ x = checkpoint(func, x)
+ else:
+ x = func(x)
+ x = self.norm(x.float())
+
+ if self.training and self.mask_ratio > 0:
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
+ x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
+ x = x_
+ return torch.reshape(x, (B, self.num_patches * self.embed_dim))
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.feature(x)
+ return x
diff --git a/arcface_torch/configs/3millions.py b/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b110223a00e6c1975709eb58968766792f18dd1
--- /dev/null
+++ b/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512 # total_batch_size = batch_size * num_gpus
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 30 * 10000
+config.num_image = 100000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.val_targets = []
diff --git a/arcface_torch/configs/__init__.py b/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/arcface_torch/configs/base.py b/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7c30bec70a7173114e8b29e492cbc483ab55a6c
--- /dev/null
+++ b/arcface_torch/configs/base.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+
+# Margin Base Softmax
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.save_all_states = False
+config.output = "ms1mv3_arcface_r50"
+
+config.embedding_size = 512
+
+# Partial FC
+config.sample_rate = 1
+config.interclass_filtering_threshold = 0
+
+config.fp16 = False
+config.batch_size = 128
+
+# For SGD
+config.optimizer = "sgd"
+config.lr = 0.1
+config.momentum = 0.9
+config.weight_decay = 5e-4
+
+# For AdamW
+# config.optimizer = "adamw"
+# config.lr = 0.001
+# config.weight_decay = 0.1
+
+config.verbose = 2000
+config.frequent = 10
+
+# For Large Sacle Dataset, such as WebFace42M
+config.dali = False
+
+# Gradient ACC
+config.gradient_acc = 1
+
+# setup seed
+config.seed = 2048
+
+# dataload numworkers
+config.num_workers = 2
+
+# WandB Logger
+config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+config.suffix_run_name = None
+config.using_wandb = False
+config.wandb_entity = "entity"
+config.wandb_project = "project"
+config.wandb_log_all = True
+config.save_artifacts = False
+config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted
diff --git a/arcface_torch/configs/glint360k_mbf.py b/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..03447e982487f19c40c814448f9fdfea6c306b0f
--- /dev/null
+++ b/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/glint360k_r100.py b/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6676e1f92d2f2d2c7f5ef5b0d03f18311d0b48
--- /dev/null
+++ b/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/glint360k_r50.py b/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bd79b92986294ff5cb1f53afc41f8b07e5dc08
--- /dev/null
+++ b/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv2_mbf.py b/arcface_torch/configs/ms1mv2_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..098afd8d2d6ca353d0b02281d02ac54e584f8281
--- /dev/null
+++ b/arcface_torch/configs/ms1mv2_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv2_r100.py b/arcface_torch/configs/ms1mv2_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..24fd0417f2219e63e91fdbc92c609ebc596cee21
--- /dev/null
+++ b/arcface_torch/configs/ms1mv2_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv2_r50.py b/arcface_torch/configs/ms1mv2_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..236721a526489b2cac7ba66a22bfc3d650e744cd
--- /dev/null
+++ b/arcface_torch/configs/ms1mv2_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv3_mbf.py b/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb093f42440a0cb0c3bfdf7172f7e2fa478619c7
--- /dev/null
+++ b/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv3_r100.py b/arcface_torch/configs/ms1mv3_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..98263fc00dd15f0ad99c2a24a398433fa1c563f8
--- /dev/null
+++ b/arcface_torch/configs/ms1mv3_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv3_r50.py b/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1a4b5d7eebf5df9a7340e07a003450fd1df976
--- /dev/null
+++ b/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/ms1mv3_r50_onegpu.py b/arcface_torch/configs/ms1mv3_r50_onegpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e2e92ac44d92d76682dd083afd57920516e229
--- /dev/null
+++ b/arcface_torch/configs/ms1mv3_r50_onegpu.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.02
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf12m_conflict_r50.py b/arcface_torch/configs/wf12m_conflict_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de94fcb32cad796bda63521e4f81a4f7fe88923b
--- /dev/null
+++ b/arcface_torch/configs/wf12m_conflict_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py b/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
new file mode 100644
index 0000000000000000000000000000000000000000..a766f4154bb801b57d0f9519748b63941e349330
--- /dev/null
+++ b/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py b/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1018b7f0d0320678b33b212eed5751badf72ee
--- /dev/null
+++ b/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_flip_r50.py b/arcface_torch/configs/wf12m_flip_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde56fed6d8513b95882b7701f93f8574afbca9c
--- /dev/null
+++ b/arcface_torch/configs/wf12m_flip_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_mbf.py b/arcface_torch/configs/wf12m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cb93b2f168e3a64e65d1f8d6cf058e41676c6a
--- /dev/null
+++ b/arcface_torch/configs/wf12m_mbf.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_pfc02_r100.py b/arcface_torch/configs/wf12m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f0f0ec0ce5c523bace8b7869181ea807e72423
--- /dev/null
+++ b/arcface_torch/configs/wf12m_pfc02_r100.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_r100.py b/arcface_torch/configs/wf12m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..2663dc950c42f699428d92e7349a1cf5ed8d848d
--- /dev/null
+++ b/arcface_torch/configs/wf12m_r100.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/arcface_torch/configs/wf12m_r50.py b/arcface_torch/configs/wf12m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7284663d6afbe6f205c8c9f10cd454ef1045ca
--- /dev/null
+++ b/arcface_torch/configs/wf12m_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py b/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..2885816cb9b635c526d1d2269c606e93fa54a2e6
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py b/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a6bb79da7eaa3f111e9efedf507e46a953c9aa
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py b/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..035684732003b5c7b8fe8ea34e097bd22fbcca37
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 256
+config.lr = 0.3
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 1
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py b/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02bdf3afe8370086cf64fd112244b00cee35a6f
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.6
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 4
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py b/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8407943ffef4ae3ee02ddb3f2361a9ac655cbb
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py b/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f627fa94046d22ab0f0f12a8e339dc2cedfd81
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_r100.py b/arcface_torch/configs/wf42m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe402f9f1a3ae044b9ed7150c5743141ed3f1b1
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py b/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..9916872b3af4330448f70f3cf72d45be5a200f6d
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.2
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py b/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..22dcbf11f7e5ea3943068bf146be400210505570
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py b/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf21c97a8c7c0568d0783432b4526ba78138926
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py b/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d35830ba107f27eea9b849abe88b0b4b09bdd0c
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py b/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34dd1c11f489d9c5c1b23c3677d303aafe46da6
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r200"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py b/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44a5d771e17ecbeffe3437f3500e9d0c9dcc105
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe7fe6b1ecde9034cf6b647c0558f96bb1d41c3
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b153aa6a36a9a883153245c49617c2d9e11939
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_l_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ce7010d9c297ed0832dcb5639d552078cea95c
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_s_dp005_mask_0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..8516755b656b21536da177402ef6066e3e1039dd
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py b/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f6559ad3d66659dba3bc9c29e35c76a62b3576
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 256
+config.gradient_acc = 12 # total batchsize is 256 * 12
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py b/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf8c563dab6ce4f45b694efa4837a4d52a98af3
--- /dev/null
+++ b/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 512
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/arcface_torch/configs/wf4m_mbf.py b/arcface_torch/configs/wf4m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ee67b62acb4432b9d4916400ec79433f7dd10ea
--- /dev/null
+++ b/arcface_torch/configs/wf4m_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf4m_r100.py b/arcface_torch/configs/wf4m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..914d71987fdf2cbffe51a3e17938bc1047e1d319
--- /dev/null
+++ b/arcface_torch/configs/wf4m_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/configs/wf4m_r50.py b/arcface_torch/configs/wf4m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44fc68da88dd2c2d1e003c345ef04a5f43ead86
--- /dev/null
+++ b/arcface_torch/configs/wf4m_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/arcface_torch/dataset.py b/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..595eda79c56400a3243b2bd0d13a0dce9b8afd1d
--- /dev/null
+++ b/arcface_torch/dataset.py
@@ -0,0 +1,268 @@
+import numbers
+import os
+import queue as Queue
+import threading
+from functools import partial
+from typing import Iterable
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import ImageFolder
+from utils.utils_distributed_sampler import DistributedSampler
+from utils.utils_distributed_sampler import get_dist_info
+from utils.utils_distributed_sampler import worker_init_fn
+
+
+def get_dataloader(
+ root_dir,
+ local_rank,
+ batch_size,
+ dali=False,
+ seed=2048,
+ num_workers=2,
+) -> Iterable:
+
+ rec = os.path.join(root_dir, "train.rec")
+ idx = os.path.join(root_dir, "train.idx")
+ train_set = None
+
+ # Synthetic
+ if root_dir == "synthetic":
+ train_set = SyntheticDataset()
+ dali = False
+
+ # Mxnet RecordIO
+ elif os.path.exists(rec) and os.path.exists(idx):
+ train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
+
+ # Image Folder
+ else:
+ transform = transforms.Compose(
+ [
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+ train_set = ImageFolder(root_dir, transform)
+
+ # DALI
+ if dali:
+ return dali_data_iter(batch_size=batch_size, rec_file=rec, idx_file=idx, num_threads=2, local_rank=local_rank)
+
+ rank, world_size = get_dist_info()
+ train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)
+
+ if seed is None:
+ init_fn = None
+ else:
+ init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
+
+ train_loader = DataLoaderX(
+ local_rank=local_rank,
+ dataset=train_set,
+ batch_size=batch_size,
+ sampler=train_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ worker_init_fn=init_fn,
+ )
+
+ return train_loader
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class DataLoaderX(DataLoader):
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [
+ transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, "train.rec")
+ path_imgidx = os.path.join(root_dir, "train.idx")
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+ def __init__(self):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
+
+
+def dali_data_iter(
+ batch_size: int,
+ rec_file: str,
+ idx_file: str,
+ num_threads: int,
+ initial_fill=32768,
+ random_shuffle=True,
+ prefetch_queue_depth=1,
+ local_rank=0,
+ name="reader",
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5),
+):
+ """
+ Parameters:
+ ----------
+ initial_fill: int
+ Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
+
+ """
+ rank: int = distributed.get_rank()
+ world_size: int = distributed.get_world_size()
+ import nvidia.dali.fn as fn
+ import nvidia.dali.types as types
+ from nvidia.dali.pipeline import Pipeline
+ from nvidia.dali.plugin.pytorch import DALIClassificationIterator
+
+ pipe = Pipeline(
+ batch_size=batch_size,
+ num_threads=num_threads,
+ device_id=local_rank,
+ prefetch_queue_depth=prefetch_queue_depth,
+ )
+ condition_flip = fn.random.coin_flip(probability=0.5)
+ with pipe:
+ jpegs, labels = fn.readers.mxnet(
+ path=rec_file,
+ index_path=idx_file,
+ initial_fill=initial_fill,
+ num_shards=world_size,
+ shard_id=rank,
+ random_shuffle=random_shuffle,
+ pad_last_batch=False,
+ name=name,
+ )
+ images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
+ images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
+ pipe.set_outputs(images, labels)
+ pipe.build()
+ return DALIWarper(
+ DALIClassificationIterator(
+ pipelines=[pipe],
+ reader_name=name,
+ )
+ )
+
+
+@torch.no_grad()
+class DALIWarper(object):
+ def __init__(self, dali_iter):
+ self.iter = dali_iter
+
+ def __next__(self):
+ data_dict = self.iter.__next__()[0]
+ tensor_data = data_dict["data"].cuda()
+ tensor_label: torch.Tensor = data_dict["label"].cuda().long()
+ tensor_label.squeeze_()
+ return tensor_data, tensor_label
+
+ def __iter__(self):
+ return self
+
+ def reset(self):
+ self.iter.reset()
diff --git a/arcface_torch/dist.sh b/arcface_torch/dist.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9f3c6a5276a030652c9f2e81d535e0beb854f123
--- /dev/null
+++ b/arcface_torch/dist.sh
@@ -0,0 +1,15 @@
+ip_list=("ip1" "ip2" "ip3" "ip4")
+
+config=wf42m_pfc03_32gpu_r100
+
+for((node_rank=0;node_rank<${#ip_list[*]};node_rank++));
+do
+ ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun \
+ --nproc_per_node=8 \
+ --nnodes=${#ip_list[*]} \
+ --node_rank=$node_rank \
+ --master_addr=${ip_list[0]} \
+ --master_port=22345 train.py configs/$config" &
+done
diff --git a/arcface_torch/docs/eval.md b/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ce1621357c03ee8a25c004e5f01850990df1628
--- /dev/null
+++ b/arcface_torch/docs/eval.md
@@ -0,0 +1,43 @@
+## Eval on ICCV2021-MFR
+
+coming soon.
+
+
+## Eval IJBC
+You can eval ijbc with pytorch or onnx.
+
+
+1. Eval IJBC With Onnx
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC With Pytorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
+
+
+## Result
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) |
+|:---------------|:--------------------|:------------|:------------|:------------|
+| WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 |
+| WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 |
\ No newline at end of file
diff --git a/arcface_torch/docs/install.md b/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..8824e7e3108adc76cee514a3e66a50f933c9c91f
--- /dev/null
+++ b/arcface_torch/docs/install.md
@@ -0,0 +1,27 @@
+# Installation
+
+### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110)
+#### Linux and Windows
+- CUDA 11.3
+```shell
+
+pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102
+```
+
+### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190)
+#### Linux and Windows
+
+- CUDA 11.1
+```shell
+pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
diff --git a/arcface_torch/docs/install_dali.md b/arcface_torch/docs/install_dali.md
new file mode 100644
index 0000000000000000000000000000000000000000..48743644d0dac8885efaecfbb7821d5639a4f732
--- /dev/null
+++ b/arcface_torch/docs/install_dali.md
@@ -0,0 +1,103 @@
+# Installation
+## Prerequisites
+
+1. Linux x64.
+2. NVIDIA Driver supporting CUDA 10.0 or later (i.e., 410.48 or later driver releases).
+3. (Optional) One or more of the following deep learning frameworks:
+
+ * [MXNet 1.3](http://mxnet.incubator.apache.org/) `mxnet-cu100` or later.
+ * [PyTorch 0.4](https://pytorch.org/) or later.
+ * [TensorFlow 1.7](https://www.tensorflow.org/) or later.
+
+## DALI in NGC Containers
+DALI is preinstalled in the TensorFlow, PyTorch, and MXNet containers in versions 18.07 and later on NVIDIA GPU Cloud.
+
+## pip - Official Releases
+
+### nvidia-dali
+
+Execute the following command to install the latest DALI for specified CUDA version (please check support matrix to see if your platform is supported):
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110
+ ```
+
+
+> Note: CUDA 11.0 build uses CUDA toolkit enhanced compatibility. It is built with the latest CUDA 11.x toolkit while it can run on the latest, stable CUDA 11.0 capable drivers (450.80 or later). Using the latest driver may enable additional functionality. More details can be found in [enhanced CUDA compatibility guide](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#enhanced-compat-minor-releases).
+
+> Note: Please always use the latest version of pip available (at least >= 19.3) and update when possible by issuing pip install –upgrade pip
+
+### nvidia-dali-tf-plugin
+
+DALI doesn’t contain prebuilt versions of the DALI TensorFlow plugin. It needs to be installed as a separate package which will be built against the currently installed version of TensorFlow:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda110
+ ```
+
+Installing this package will install `nvidia-dali-cudaXXX` and its dependencies, if they are not already installed. The package `tensorflow-gpu` must be installed before attempting to install `nvidia-dali-tf-plugin-cudaXXX`.
+
+> Note: The packages `nvidia-dali-tf-plugin-cudaXXX` and `nvidia-dali-cudaXXX` should be in exactly the same version. Therefore, installing the latest `nvidia-dali-tf-plugin-cudaXXX`, will replace any older `nvidia-dali-cudaXXX` version already installed. To work with older versions of DALI, provide the version explicitly to the `pip install` command.
+
+### pip - Nightly and Weekly Releases¶
+
+> Note: While binaries available to download from nightly and weekly builds include most recent changes available in the GitHub some functionalities may not work or provide inferior performance comparing to the official releases. Those builds are meant for the early adopters seeking for the most recent version available and being ready to boldly go where no man has gone before.
+
+> Note: It is recommended to uninstall regular DALI and TensorFlow plugin before installing nightly or weekly builds as they are installed in the same path
+
+#### Nightly Builds
+To access most recent nightly builds please use flowing release channel:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda102
+ ```
+
+ ```
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda110
+ ```
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda110
+ ```
+
+
+#### Weekly Builds
+
+Also, there is a weekly release channel with more thorough testing. To access most recent weekly builds please use the following release channel (available only for CUDA 11):
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-weekly-cuda110
+```
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-tf-plugin-week
+```
+
+
+---
+
+### For more information about Dali and installation, please refer to [DALI documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
diff --git a/arcface_torch/docs/modelzoo.md b/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/arcface_torch/docs/prepare_custom_dataset.md b/arcface_torch/docs/prepare_custom_dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..6fc18dbd33cfa68be61e73906b0c96a320a8e12c
--- /dev/null
+++ b/arcface_torch/docs/prepare_custom_dataset.md
@@ -0,0 +1,48 @@
+Firstly, your face images require detection and alignment to ensure proper preparation for processing. Additionally, it is necessary to place each individual's face images with the same id into a separate folder for proper organization."
+
+
+```shell
+# directories and files for yours datsaets
+/image_folder
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmepeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using follow command
+python -m mxnet.tools.im2rec --list --recursive train image_folder
+
+# 2) create train.rec and train.idx using train.lst using following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
diff --git a/arcface_torch/docs/prepare_webface42m.md b/arcface_torch/docs/prepare_webface42m.md
new file mode 100644
index 0000000000000000000000000000000000000000..e799ba74e04f911593a704e64810c1e9936307ff
--- /dev/null
+++ b/arcface_torch/docs/prepare_webface42m.md
@@ -0,0 +1,58 @@
+
+
+
+## 1. Download Datasets and Unzip
+
+The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html.
+Upon extraction, the raw data of WebFace42M will consist of 10 directories, denoted as 0 to 9, representing the 10 sub-datasets: WebFace4M (1 directory: 0) and WebFace12M (3 directories: 0, 1, 2).
+
+## 2. Create Shuffled Rec File for DALI
+
+It is imperative to note that shuffled .rec files are crucial for DALI and the absence of shuffling in .rec files can result in decreased performance. Original .rec files generated in the InsightFace style are not compatible with Nvidia DALI and it is necessary to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file.
+
+
+```shell
+# directories and files for yours datsaets
+/WebFace42M_Root
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmepeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using follow command
+python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root
+
+# 2) create train.rec and train.idx using train.lst using following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
diff --git a/arcface_torch/docs/speed_benchmark.md b/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea
--- /dev/null
+++ b/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+You need to use the following two commands to test the Partial FC training performance.
+The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50,
+batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) trainging.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) trainging.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel,
+and the training speed is 2.5 times faster than the model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/arcface_torch/eval/__init__.py b/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/arcface_torch/eval/verification.py b/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd8c7d08c5e671a55e4d03c0d9714d60e7f059d1
--- /dev/null
+++ b/arcface_torch/eval/verification.py
@@ -0,0 +1,378 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca=0):
+ assert embeddings1.shape[0] == embeddings2.shape[0]
+ assert embeddings1.shape[1] == embeddings2.shape[1]
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print("doing pca on", fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set], actual_issame[test_set]
+ )
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set], actual_issame[test_set]
+ )
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
+ assert embeddings1.shape[0] == embeddings2.shape[0]
+ assert embeddings1.shape[1] == embeddings2.shape[1]
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind="slinear")
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(
+ thresholds, embeddings1, embeddings2, np.asarray(actual_issame), nrof_folds=nrof_folds, pca=pca
+ )
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(
+ thresholds, embeddings1, embeddings2, np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds
+ )
+ return tpr, fpr, accuracy, val, val_std, far
+
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, "rb") as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, "rb") as f:
+ bins, issame_list = pickle.load(f, encoding="bytes") # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print("loading bin", idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print("testing verification..")
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size : bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count) :, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ embeddings = embeddings_list[0].copy()
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print("infer time", time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set, backbone, batch_size, name="", data_extra=None, label_shape=None):
+ print("dump verification embedding..")
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra), label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count) :, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join("temp.bin")
+ with open(outname, "wb") as f:
+ pickle.dump((embeddings, issame_list), f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/arcface_torch/eval_ijbc.py b/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c3506a8db432049e16b9235d85efe58109b5a8
--- /dev/null
+++ b/arcface_torch/eval_ijbc.py
@@ -0,0 +1,450 @@
+# coding: utf-8
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description="do ijb test")
+# general
+parser.add_argument("--model-prefix", default="", help="path to load model.")
+parser.add_argument("--image-path", default="", type=str, help="")
+parser.add_argument("--result-dir", default=".", type=str, help="")
+parser.add_argument("--batch-size", default=128, type=int, help="")
+parser.add_argument("--network", default="iresnet50", type=str, help="")
+parser.add_argument("--job", default="insightface", type=str, help="job name")
+parser.add_argument("--target", default="IJBC", type=str, help="target, set to IJBC or IJBB")
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True # if Ture, TestMode(N1)
+use_detector_score = True # if Ture, TestMode(D1)
+use_flip_test = True # if Ture, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array(
+ [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]],
+ dtype=np.float32,
+ )
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg, M, (self.image_size[1], self.image_size[0]), borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=" ", header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print("files:", len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[: len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(" ")
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print("batch", batch)
+ img_feats[batch * batch_size : batch * batch_size + batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size :]):
+ name_lmk_score = each_line.strip().split(" ")
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print("batch", batch)
+ img_feats[len(files) - rare_size :][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print("Finish Calculating {} template features.".format(count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == "IJBC" or target == "IJBB"
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join("%s/meta" % image_path, "%s_face_tid_mid.txt" % target.lower())
+)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % target.lower())
+)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = "%s/loose_crop" % image_path
+img_list_path = "%s/meta/%s_name_5pts_score.txt" % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list, model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+print("Feature Shape: ({} , {}) .".format(img_feats.shape[0], img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2 :]
+else:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2]
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats**2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print("Time: %.2f s. " % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2")))
+x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(
+ fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100))
+ )
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10**-6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle="--", linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale("log")
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.title("ROC on IJB")
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, "%s.pdf" % target.lower()))
+print(tpr_fpr_table)
diff --git a/arcface_torch/flops.py b/arcface_torch/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..62aa8ec433846693a0e71e6ab808048ca37e61fd
--- /dev/null
+++ b/arcface_torch/flops.py
@@ -0,0 +1,20 @@
+import argparse
+
+from backbones import get_model
+from ptflops import get_model_complexity_info
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("n", type=str, default="r100")
+ args = parser.parse_args()
+ net = get_model(args.n)
+ macs, params = get_model_complexity_info(
+ net, (3, 112, 112), as_strings=False, print_per_layer_stat=True, verbose=True
+ )
+ gmacs = macs / (1000**3)
+ print("%.3f GFLOPs" % gmacs)
+ print("%.3f Mparams" % (params / (1000**2)))
+
+ if hasattr(net, "extra_gflops"):
+ print("%.3f Extra-GFLOPs" % net.extra_gflops)
+ print("%.3f Total-GFLOPs" % (gmacs + net.extra_gflops))
diff --git a/arcface_torch/inference.py b/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aab06628b4f33a67284ea1446ddc7c38642c33f
--- /dev/null
+++ b/arcface_torch/inference.py
@@ -0,0 +1,34 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="PyTorch ArcFace Training")
+ parser.add_argument("--network", type=str, default="r50", help="backbone network")
+ parser.add_argument("--weight", type=str, default="")
+ parser.add_argument("--img", type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
diff --git a/arcface_torch/losses.py b/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7805d8f088e9b91f48b29d8304f87927ca65e0c4
--- /dev/null
+++ b/arcface_torch/losses.py
@@ -0,0 +1,95 @@
+import math
+
+import torch
+
+
+class CombinedMarginLoss(torch.nn.Module):
+ def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
+ super().__init__()
+ self.s = s
+ self.m1 = m1
+ self.m2 = m2
+ self.m3 = m3
+ self.interclass_filtering_threshold = interclass_filtering_threshold
+
+ # For ArcFace
+ self.cos_m = math.cos(self.m2)
+ self.sin_m = math.sin(self.m2)
+ self.theta = math.cos(math.pi - self.m2)
+ self.sinmm = math.sin(math.pi - self.m2) * self.m2
+ self.easy_margin = False
+
+ def forward(self, logits, labels):
+ index_positive = torch.where(labels != -1)[0]
+
+ if self.interclass_filtering_threshold > 0:
+ with torch.no_grad():
+ dirty = logits > self.interclass_filtering_threshold
+ dirty = dirty.float()
+ mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
+ mask.scatter_(1, labels[index_positive], 0)
+ dirty[index_positive] *= mask
+ tensor_mul = 1 - dirty
+ logits = tensor_mul * logits
+
+ target_logit = logits[index_positive, labels[index_positive].view(-1)]
+
+ if self.m1 == 1.0 and self.m3 == 0.0:
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.m2
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.s
+
+ elif self.m3 > 0:
+ final_target_logit = target_logit - self.m3
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits = logits * self.s
+ else:
+ raise
+
+ return logits
+
+
+class ArcFace(torch.nn.Module):
+ """ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):"""
+
+ def __init__(self, s=64.0, margin=0.5):
+ super(ArcFace, self).__init__()
+ self.scale = s
+ self.margin = margin
+ self.cos_m = math.cos(margin)
+ self.sin_m = math.sin(margin)
+ self.theta = math.cos(math.pi - margin)
+ self.sinmm = math.sin(math.pi - margin) * margin
+ self.easy_margin = False
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.margin
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.s
+ return logits
+
+
+class CosFace(torch.nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+ final_target_logit = target_logit - self.m
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits = logits * self.s
+ return logits
diff --git a/arcface_torch/lr_scheduler.py b/arcface_torch/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3020ff343d3333d18cdf9102d6c66be29bab33fa
--- /dev/null
+++ b/arcface_torch/lr_scheduler.py
@@ -0,0 +1,28 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolyScheduler(_LRScheduler):
+ def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
+ self.base_lr = base_lr
+ self.warmup_lr_init = 0.0001
+ self.max_steps: int = max_steps
+ self.warmup_steps: int = warmup_steps
+ self.power = 2
+ super(PolyScheduler, self).__init__(optimizer, -1, False)
+ self.last_epoch = last_epoch
+
+ def get_warmup_lr(self):
+ alpha = float(self.last_epoch) / float(self.warmup_steps)
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
+
+ def get_lr(self):
+ if self.last_epoch == -1:
+ return [self.warmup_lr_init for _ in self.optimizer.param_groups]
+ if self.last_epoch < self.warmup_steps:
+ return self.get_warmup_lr()
+ else:
+ alpha = pow(
+ 1 - float(self.last_epoch - self.warmup_steps) / float(self.max_steps - self.warmup_steps),
+ self.power,
+ )
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
diff --git a/arcface_torch/onnx_helper.py b/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..95f615fd7f3e0586be123d9a6538f68386158360
--- /dev/null
+++ b/arcface_torch/onnx_helper.py
@@ -0,0 +1,264 @@
+from __future__ import division
+
+import argparse
+import datetime
+import glob
+import os
+import os.path as osp
+import sys
+
+import cv2
+import numpy as np
+import onnx
+import onnxruntime
+from insightface.data import get_image
+from onnx import numpy_helper
+
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ["CPUExecutionProvider"] if cpu else None
+
+ # input_size is (w,h), return error message, return None if success
+ def check(self, track="cfat", test_img=None):
+ # default is cfat
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ if track.startswith("ms1m"):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 10
+ elif track.startswith("glint"):
+ max_model_size_mb = 1024
+ max_feat_dim = 1024
+ max_time_cost = 20
+ elif track.startswith("cfat"):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith("unconstrained"):
+ max_model_size_mb = 1024
+ max_feat_dim = 1024
+ max_time_cost = 30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith(".onnx"):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files) == 0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print("use onnx-model:", self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print("input-shape:", input_shape)
+ if len(input_shape) != 4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ # return "input_shape[0] should be str to support batch-inference"
+ print("reset input-shape[0] to None")
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
+ new_model_file = osp.join(self.model_path, "zzzzrefined.onnx")
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print("use new onnx-model:", self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print("new-input-shape:", input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ # print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ # print(o.name, o.shape)
+ if len(output_names) != 1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ # print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node) < 8:
+ return "too small onnx graph"
+
+ input_size = (112, 112)
+ self.crop = None
+ if track == "cfat":
+ crop_file = osp.join(self.model_path, "crop.txt")
+ if osp.exists(crop_file):
+ lines = open(crop_file, "r").readlines()
+ if len(lines) != 6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size != self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s" % (input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024 * 1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB" % self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track == "cfat":
+ pn_file = osp.join(self.model_path, "pixel_norm.txt")
+ if osp.exists(pn_file):
+ lines = open(pn_file, "r").readlines()
+ if len(lines) != 2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith("Sub") or node.name.startswith("_minus"):
+ find_sub = True
+ if node.name.startswith("Mul") or node.name.startswith("_mul") or node.name.startswith("Div"):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ # mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize < 4:
+ return "invalid weight type - (%s:%s)" % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image("Tom_Hanks_54745")
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float("inf"), -float("inf")] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d" % feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost * 1000
+ if cost_ms > max_time_cost:
+ return "max time cost exceed, given %.4f" % cost_ms
+ self.cost_ms = cost_ms
+ print(
+ "check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f"
+ % (self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)
+ )
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [
+ img,
+ ] * 32
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs,
+ scalefactor=1.0 / self.input_std,
+ size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean),
+ swapRB=True,
+ )
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+ def meta_info(self):
+ return {"model-size-mb": self.model_size_mb, "feature-dim": self.feat_dim, "infer": self.cost_ms}
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != input_size[1] or nimg.shape[1] != input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ imgs, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True
+ )
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1] : self.crop[3], self.crop[0] : self.crop[2], :]
+ if nimg.shape[0] != input_size[1] or nimg.shape[1] != input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(
+ img, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True
+ )
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb - ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="")
+ # general
+ parser.add_argument("workdir", help="submitted work dir", type=str)
+ parser.add_argument("--track", help="track name, for different challenge", type=str, default="cfat")
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print("err:", err)
diff --git a/arcface_torch/onnx_ijbc.py b/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d1bff03699c96139f3f9e9b52998cba592d9d72
--- /dev/null
+++ b/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,262 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+import torch
+from onnx_helper import ArcFaceORT
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+from torch.utils.data import DataLoader
+
+SRC = np.array(
+ [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]],
+ dtype=np.float32,
+)
+SRC[:, 0] += 8.0
+
+
+@torch.no_grad()
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(" ")
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ return torch.from_numpy(output)
+
+
+@torch.no_grad()
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def collate_fn(data):
+ return torch.cat(data, dim=0)
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=128,
+ drop_last=False,
+ num_workers=4,
+ collate_fn=collate_fn,
+ )
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.numpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter : 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=" ", header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, "rb") as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True),
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print("Finish Calculating {} template features.".format(count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i : i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print("Finish {}/{} pairs.".format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True # if Ture, TestMode(N1)
+ use_detector_score = True # if Ture, TestMode(D1)
+ use_flip_test = True # if Ture, TestMode(F1)
+ assert args.target == "IJBC" or args.target == "IJBB"
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join("%s/meta" % args.image_path, "%s_face_tid_mid.txt" % args.target.lower())
+ )
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join("%s/meta" % args.image_path, "%s_template_pair_label.txt" % args.target.lower())
+ )
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = "%s/loose_crop" % args.image_path
+ img_list_path = "%s/meta/%s_name_5pts_score.txt" % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+ print("Feature Shape: ({} , {}) .".format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2 :]
+ else:
+ img_input_feats = img_feats[:, 0 : img_feats.shape[1] // 2]
+
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats**2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print("Time: %.2f s. " % (stop - start))
+ result_dir = args.model_root
+
+ save_path = os.path.join(result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+ tpr_fpr_table = prettytable.PrettyTable(["Methods"] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="do ijb test")
+ # general
+ parser.add_argument("--model-root", default="", help="path to load model.")
+ parser.add_argument("--image-path", default="/train_tmp/IJB_release/IJBC", type=str, help="")
+ parser.add_argument("--target", default="IJBC", type=str, help="target, set to IJBC or IJBB")
+ main(parser.parse_args())
diff --git a/arcface_torch/partial_fc.py b/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7891527d6c396a6b51a67daf06593d4db5cce43
--- /dev/null
+++ b/arcface_torch/partial_fc.py
@@ -0,0 +1,490 @@
+import collections
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear
+from torch.nn.functional import normalize
+
+
+class PartialFC(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
+
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of
+ negative class centers are selected to compute the margin-based softmax loss, all class
+ centers are still maintained throughout the whole training process, but only a subset is
+ selected and updated in each iteration.
+
+ .. note::
+ When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
+
+ Example:
+ --------
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels, optimizer)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+
+ _version = 1
+
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_mom: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_mom: torch.Tensor
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_mom", tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_mom", tensor=torch.empty(0, 0))
+ self.register_buffer("weight_index", tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ @torch.no_grad()
+ def sample(self, labels: torch.Tensor, index_positive: torch.Tensor, optimizer: torch.optim.Optimizer):
+ """
+ This functions will change the value of labels
+
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ pass
+ index_positive: torch.Tensor
+ pass
+ optimizer: torch.optim.Optimizer
+ pass
+ """
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_mom = self.weight_mom[self.weight_index]
+
+ if isinstance(optimizer, torch.optim.SGD):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated]["momentum_buffer"] = self.weight_activated_mom
+ else:
+ raise
+
+ @torch.no_grad()
+ def update(self):
+ """partial weight to global"""
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_mom[self.weight_index] = self.weight_activated_mom
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size
+ )
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_mom.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_mom.zero_()
+ self.weight_index.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class PartialFCAdamW(torch.nn.Module):
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFCAdamW, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_exp_avg: torch.Tensor
+ self.weight_exp_avg_sq: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_exp_avg: torch.Tensor
+ self.weight_activated_exp_avg_sq: torch.Tensor
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_exp_avg", tensor=torch.zeros_like(self.weight))
+ self.register_buffer("weight_exp_avg_sq", tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_exp_avg", tensor=torch.empty(0, 0))
+ self.register_buffer("weight_activated_exp_avg_sq", tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.step = 0
+
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ @torch.no_grad()
+ def sample(self, labels, index_positive, optimizer):
+ self.step += 1
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
+ self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]
+
+ if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
+ optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
+ optimizer.state[self.weight_activated]["step"] = self.step
+ else:
+ raise
+
+ @torch.no_grad()
+ def update(self):
+ """partial weight to global"""
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
+ self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size
+ )
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_exp_avg.zero_()
+ self.weight_exp_avg_sq.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_exp_avg.zero_()
+ self.weight_activated_exp_avg_sq.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
+ Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_grad (torch.Tensor): gradient backward by last layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
diff --git a/arcface_torch/partial_fc_v2.py b/arcface_torch/partial_fc_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..45078e430a6b0cd442ff65618093689822711aef
--- /dev/null
+++ b/arcface_torch/partial_fc_v2.py
@@ -0,0 +1,247 @@
+import math
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear
+from torch.nn.functional import normalize
+
+
+class PartialFC_V2(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of
+ negative class centers are selected to compute the margin-based softmax loss, all class
+ centers are still maintained throughout the whole training process, but only a subset is
+ selected and updated in each iteration.
+ .. note::
+ When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
+ Example:
+ --------
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+
+ _version = 2
+
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Paramenters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC_V2, self).__init__()
+ assert distributed.is_initialized(), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+ self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise
+
+ def sample(self, labels, index_positive):
+ """
+ This functions will change the value of labels
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ pass
+ index_positive: torch.Tensor
+ pass
+ optimizer: torch.optim.Optimizer
+ pass
+ """
+ with torch.no_grad():
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ return self.weight[self.weight_index]
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+ Returns:
+ -------
+ loss: torch.Tensor
+ pass
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert (
+ self.last_batch_size == batch_size
+ ), f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}"
+
+ _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
+ _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ weight = self.sample(labels, index_positive)
+ else:
+ weight = self.weight
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(weight)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
+ Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_grad (torch.Tensor): gradient backward by last layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
diff --git a/arcface_torch/requirement.txt b/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1a431ef9c39b258b676411f1081ed9006a8b817
--- /dev/null
+++ b/arcface_torch/requirement.txt
@@ -0,0 +1,6 @@
+tensorboard
+easydict
+mxnet
+onnx
+sklearn
+opencv-python
\ No newline at end of file
diff --git a/arcface_torch/run.sh b/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6eacdf8e814d7bd68650c7eda8f72687ee74db16
--- /dev/null
+++ b/arcface_torch/run.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train_v2.py $@
diff --git a/arcface_torch/scripts/shuffle_rec.py b/arcface_torch/scripts/shuffle_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..1607fb2db48b9b32f4fa16c6ad97d15582820b2a
--- /dev/null
+++ b/arcface_torch/scripts/shuffle_rec.py
@@ -0,0 +1,81 @@
+import argparse
+import multiprocessing
+import os
+import time
+
+import mxnet as mx
+import numpy as np
+
+
+def read_worker(args, q_in):
+ path_imgidx = os.path.join(args.input, "train.idx")
+ path_imgrec = os.path.join(args.input, "train.rec")
+ imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
+
+ s = imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ assert header.flag > 0
+
+ imgidx = np.array(range(1, int(header.label[0])))
+ np.random.shuffle(imgidx)
+
+ for idx in imgidx:
+ item = imgrec.read_idx(idx)
+ q_in.put(item)
+
+ q_in.put(None)
+ imgrec.close()
+
+
+def write_worker(args, q_out):
+ pre_time = time.time()
+
+ if args.input[-1] == "/":
+ args.input = args.input[:-1]
+ dirname = os.path.dirname(args.input)
+ basename = os.path.basename(args.input)
+ output = os.path.join(dirname, f"shuffled_{basename}")
+ os.makedirs(output, exist_ok=True)
+
+ path_imgidx = os.path.join(output, "train.idx")
+ path_imgrec = os.path.join(output, "train.rec")
+ save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
+ more = True
+ count = 0
+ while more:
+ deq = q_out.get()
+ if deq is None:
+ more = False
+ else:
+ header, jpeg = mx.recordio.unpack(deq)
+ # TODO it is currently not fully developed
+ if isinstance(header.label, float):
+ label = header.label
+ else:
+ label = header.label[0]
+
+ header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
+ save_record.write_idx(count, mx.recordio.pack(header, jpeg))
+ count += 1
+ if count % 10000 == 0:
+ cur_time = time.time()
+ print("save time:", cur_time - pre_time, " count:", count)
+ pre_time = cur_time
+ print(count)
+ save_record.close()
+
+
+def main(args):
+ queue = multiprocessing.Queue(10240)
+ read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
+ read_process.daemon = True
+ read_process.start()
+ write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
+ write_process.start()
+ write_process.join()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input", help="path to source rec.")
+ main(parser.parse_args())
diff --git a/arcface_torch/torch2onnx.py b/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c2bb9e85c9bc5dc0b90842ad9c782d5e7cde79
--- /dev/null
+++ b/arcface_torch/torch2onnx.py
@@ -0,0 +1,56 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float)
+ img = (img / 255.0 - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight, strict=True)
+ net.eval()
+ torch.onnx.export(
+ net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset
+ )
+ model = onnx.load(output)
+ graph = model.graph
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
+ if simplify:
+ from onnxsim import simplify
+
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == "__main__":
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description="ArcFace PyTorch to onnx")
+ parser.add_argument("input", type=str, help="input backbone.pth file or path")
+ parser.add_argument("--output", type=str, default=None, help="output onnx path")
+ parser.add_argument("--network", type=str, default=None, help="backbone network")
+ parser.add_argument("--simplify", type=bool, default=False, help="onnx simplify")
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "model.pt")
+ assert os.path.exists(input_file)
+ # model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ # params = model_name.split("_")
+ # if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ # if args.network is None:
+ # args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512)
+ if args.output is None:
+ args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
+ convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
diff --git a/arcface_torch/train.py b/arcface_torch/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..3905bb0f90bb3806cd0698a322617a0d47c61390
--- /dev/null
+++ b/arcface_torch/train.py
@@ -0,0 +1,253 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc import PartialFC
+from partial_fc import PartialFCAdamW
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging
+from utils.utils_callbacks import CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter
+from utils.utils_logging import init_logging
+
+assert (
+ torch.__version__ >= "1.12.0"
+), "In order to enjoy the features of the new torch, \
+we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
+
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) if rank == 0 else None
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = (
+ wandb.init(
+ entity=cfg.wandb_entity,
+ project=cfg.wandb_project,
+ sync_tensorboard=True,
+ resume=cfg.wandb_resume,
+ name=run_name,
+ notes=cfg.notes,
+ )
+ if rank == 0 or cfg.wandb_log_all
+ else None
+ )
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(cfg.rec, local_rank, cfg.batch_size, cfg.dali, cfg.seed, cfg.num_workers)
+
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, find_unused_parameters=True
+ )
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64, cfg.margin_list[0], cfg.margin_list[1], cfg.margin_list[2], cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ momentum=0.9,
+ weight_decay=cfg.weight_decay,
+ )
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFCAdamW(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ )
+ else:
+ raise
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt, base_lr=cfg.lr, max_steps=cfg.total_step, warmup_steps=cfg.warmup_step, last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer, wandb_logger=wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step=global_step,
+ writer=summary_writer,
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
+
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ else:
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log(
+ {
+ "Loss/Step Loss": loss.item(),
+ "Loss/Train Loss": loss_am.avg,
+ "Process/Step": global_step,
+ "Process/Epoch": epoch,
+ }
+ )
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ from torch2onnx import convert_onnx
+
+ convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/arcface_torch/train_v2.py b/arcface_torch/train_v2.py
new file mode 100755
index 0000000000000000000000000000000000000000..ba3c15e6a1615f28daaab1ad225f7b61b27bdffc
--- /dev/null
+++ b/arcface_torch/train_v2.py
@@ -0,0 +1,248 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc_v2 import PartialFC_V2
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging
+from utils.utils_callbacks import CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter
+from utils.utils_logging import init_logging
+
+assert (
+ torch.__version__ >= "1.12.0"
+), "In order to enjoy the features of the new torch, \
+we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
+
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) if rank == 0 else None
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = (
+ wandb.init(
+ entity=cfg.wandb_entity,
+ project=cfg.wandb_project,
+ sync_tensorboard=True,
+ resume=cfg.wandb_resume,
+ name=run_name,
+ notes=cfg.notes,
+ )
+ if rank == 0 or cfg.wandb_log_all
+ else None
+ )
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(cfg.rec, local_rank, cfg.batch_size, cfg.dali, cfg.seed, cfg.num_workers)
+
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, find_unused_parameters=True
+ )
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64, cfg.margin_list[0], cfg.margin_list[1], cfg.margin_list[2], cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC_V2(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ momentum=0.9,
+ weight_decay=cfg.weight_decay,
+ )
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFC_V2(margin_loss, cfg.embedding_size, cfg.num_classes, cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ )
+ else:
+ raise
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt, base_lr=cfg.lr, max_steps=cfg.total_step, warmup_steps=cfg.warmup_step, last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer, wandb_logger=wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step=global_step,
+ writer=summary_writer,
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels)
+
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ if global_step % cfg.gradient_acc == 0:
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ opt.zero_grad()
+ else:
+ loss.backward()
+ if global_step % cfg.gradient_acc == 0:
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log(
+ {
+ "Loss/Step Loss": loss.item(),
+ "Loss/Train Loss": loss_am.avg,
+ "Process/Step": global_step,
+ "Process/Epoch": epoch,
+ }
+ )
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type="model")
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/arcface_torch/utils/__init__.py b/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/arcface_torch/utils/plot.py b/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1429e77d67c32ce9f6c4495c75608941bbcebc
--- /dev/null
+++ b/arcface_torch/utils/plot.py
@@ -0,0 +1,65 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import auc
+from sklearn.metrics import roc_curve
+
+with open(sys.argv[1], "r") as f:
+ files = f.readlines()
+
+files = [x.strip() for x in files]
+image_path = "/train_tmp/IJB_release/IJBC"
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=" ", header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % "ijbc"))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2")))
+x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
+tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(
+ fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100))
+ )
+ tpr_fpr_row = []
+ tpr_fpr_row.append(method)
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10**-6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle="--", linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale("log")
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.title("ROC on IJB")
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/arcface_torch/utils/utils_callbacks.py b/arcface_torch/utils/utils_callbacks.py
new file mode 100755
index 0000000000000000000000000000000000000000..6afa461dd3a163628f71499da66fc032b272e969
--- /dev/null
+++ b/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,141 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+from eval import verification
+from torch import distributed
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+ def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None):
+ self.rank: int = distributed.get_rank()
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank is 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ self.summary_writer = summary_writer
+ self.wandb_logger = wandb_logger
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(self.ver_list[i], backbone, 10, 10)
+ logging.info("[%s][%d]XNorm: %f" % (self.ver_name_list[i], global_step, xnorm))
+ logging.info("[%s][%d]Accuracy-Flip: %1.5f+-%1.5f" % (self.ver_name_list[i], global_step, acc2, std2))
+
+ self.summary_writer: SummaryWriter
+ self.summary_writer.add_scalar(
+ tag=self.ver_name_list[i],
+ scalar_value=acc2,
+ global_step=global_step,
+ )
+ if self.wandb_logger:
+ import wandb
+
+ self.wandb_logger.log(
+ {
+ f"Acc/val-Acc1 {self.ver_name_list[i]}": acc1,
+ f"Acc/val-Acc2 {self.ver_name_list[i]}": acc2,
+ # f'Acc/val-std1 {self.ver_name_list[i]}': std1,
+ # f'Acc/val-std2 {self.ver_name_list[i]}': acc2,
+ }
+ )
+
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ "[%s][%d]Accuracy-Highest: %1.5f" % (self.ver_name_list[i], global_step, self.highest_acc_list[i])
+ )
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank is 0 and num_update > 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None):
+ self.frequent: int = frequent
+ self.rank: int = distributed.get_rank()
+ self.world_size: int = distributed.get_world_size()
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.start_step: int = start_step
+ self.batch_size: int = batch_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(
+ self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler,
+ ):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float("inf")
+
+ # time_now = (time.time() - self.time_start) / 3600
+ # time_total = time_now / ((global_step + 1) / self.total_step)
+ # time_for_end = time_total - time_now
+ time_now = time.time()
+ time_sec = int(time_now - self.time_start)
+ time_sec_avg = time_sec / (global_step - self.start_step + 1)
+ eta_sec = time_sec_avg * (self.total_step - global_step - 1)
+ time_for_end = eta_sec / 3600
+ if self.writer is not None:
+ self.writer.add_scalar("time_for_end", time_for_end, global_step)
+ self.writer.add_scalar("learning_rate", learning_rate, global_step)
+ self.writer.add_scalar("loss", loss.avg, global_step)
+ if fp16:
+ msg = (
+ "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d "
+ "Fp16 Grad Scale: %2.f Required: %1.f hours"
+ % (
+ speed_total,
+ loss.avg,
+ learning_rate,
+ epoch,
+ global_step,
+ grad_scaler.get_scale(),
+ time_for_end,
+ )
+ )
+ else:
+ msg = (
+ "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d "
+ "Required: %1.f hours"
+ % (speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end)
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
diff --git a/arcface_torch/utils/utils_config.py b/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..140625ccfbc1b4b8d71470f50da7d4f88803cf11
--- /dev/null
+++ b/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith("configs/"), "config file setting must start with configs/"
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join("work_dirs", temp_module_name)
+ return cfg
diff --git a/arcface_torch/utils/utils_distributed_sampler.py b/arcface_torch/utils/utils_distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e57275fa17a0a9dbf27fd0eb941dd0fec1823f
--- /dev/null
+++ b/arcface_torch/utils/utils_distributed_sampler.py
@@ -0,0 +1,124 @@
+import math
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+def setup_seed(seed, cuda_deterministic=True):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ if cuda_deterministic: # slower, more reproducible
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ else: # faster, less reproducible
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = True
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # The seed of each worker equals to
+ # num_worker * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+
+ return rank, world_size
+
+
+def sync_random_seed(seed=None, device="cuda"):
+ """Make sure different ranks share the same seed.
+ All workers must call this function, otherwise it will deadlock.
+ This method is generally used in `DistributedSampler`,
+ because the seed should be identical across all processes
+ in the distributed group.
+ In distributed sampling, different ranks should sample non-overlapped
+ data in the dataset. Therefore, this function is used to make sure that
+ each rank shuffles the data indices in the same order based
+ on the same seed. Then different ranks could use different indices
+ to select non-overlapped data from the same data list.
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is None:
+ seed = np.random.randint(2**31)
+ assert isinstance(seed, int)
+
+ rank, world_size = get_dist_info()
+
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+
+ dist.broadcast(random_num, src=0)
+
+ return random_num.item()
+
+
+class DistributedSampler(_DistributedSampler):
+ def __init__(
+ self,
+ dataset,
+ num_replicas=None, # world_size
+ rank=None, # local_rank
+ shuffle=True,
+ seed=0,
+ ):
+
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ self.seed = sync_random_seed(seed)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ # When :attr:`shuffle=True`, this ensures all replicas
+ # use a different random ordering for each epoch.
+ # Otherwise, the next iteration of this sampler will
+ # yield the same ordering.
+ g.manual_seed(self.epoch + self.seed)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ # in case that indices is shorter than half of total_size
+ indices = (indices * math.ceil(self.total_size / len(indices)))[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
diff --git a/arcface_torch/utils/utils_logging.py b/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..823771b7d7c45fd30fe7d5284cb52ee6ad17c834
--- /dev/null
+++ b/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,40 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info("rank_id: %d" % rank)
diff --git a/benchmark/__pycache__/face_pipeline.cpython-310.pyc b/benchmark/__pycache__/face_pipeline.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4084f2c2853ce766e6bdda11fc21cfd01fb9dab
Binary files /dev/null and b/benchmark/__pycache__/face_pipeline.cpython-310.pyc differ
diff --git a/benchmark/__pycache__/inference_v2v.cpython-310.pyc b/benchmark/__pycache__/inference_v2v.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2299d2cf1dc91c008341ac7dc4e33d6f1817f2c
Binary files /dev/null and b/benchmark/__pycache__/inference_v2v.cpython-310.pyc differ
diff --git a/benchmark/__pycache__/scrfd_detect.cpython-310.pyc b/benchmark/__pycache__/scrfd_detect.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d042fdf11fccbaee3c70b77eeb0c22dcd671501
Binary files /dev/null and b/benchmark/__pycache__/scrfd_detect.cpython-310.pyc differ
diff --git a/benchmark/app_image.py b/benchmark/app_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..b35f06ab23e0d7e47fe0b7fbff3a136f2152d3ef
--- /dev/null
+++ b/benchmark/app_image.py
@@ -0,0 +1,166 @@
+import argparse
+import os
+
+import cv2
+import gradio as gr
+import kornia
+import numpy as np
+import torch
+from loguru import logger
+
+from benchmark.face_pipeline import alignFace
+from benchmark.face_pipeline import FaceDetector
+from benchmark.face_pipeline import inverse_transform_batch
+from benchmark.face_pipeline import SoftErosion
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+class ImageSwap:
+ def __init__(self, cfg, model=None):
+ self.device = cfg.device
+ self.facedetector = FaceDetector(cfg.face_detector_weights, device=self.device)
+ self.alignface = alignFace()
+
+ opt = TrainConfig()
+ opt.use_ddp = False
+
+ checkpoint = (cfg.model_path, cfg.model_idx)
+ if model is None:
+ self.model = HifiFace(
+ opt.identity_extractor_config, is_training=False, device=self.device, load_checkpoint=checkpoint
+ )
+ else:
+ self.model = model
+ self.model.eval()
+
+ self.smooth_mask = SoftErosion(kernel_size=7, threshold=0.9, iterations=7).to(self.device)
+
+ def _geometry_transfrom_warp_affine(self, swapped_image, inv_att_transforms, frame_size, square_mask):
+ swapped_image = kornia.geometry.transform.warp_affine(
+ swapped_image,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="border",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+
+ square_mask = kornia.geometry.transform.warp_affine(
+ square_mask,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+ return swapped_image, square_mask
+
+ def detect_and_align(self, image):
+ detection = self.facedetector(image)
+ if detection.score is None:
+ self.kps_window = []
+ return None, None
+ max_score_ind = np.argmax(detection.score, axis=0)
+ kps = detection.key_points[max_score_ind]
+ align_img, warp_mat = self.alignface.align_face(image, kps, 256)
+ align_img = cv2.resize(align_img, (256, 256))
+ align_img = align_img.transpose(2, 0, 1)
+ align_img = torch.from_numpy(align_img).unsqueeze(0).to(self.device).float()
+ align_img = align_img / 255.0
+ return align_img, warp_mat
+
+ def inference(self, source_face, target_face, shape_rate, id_rate, iterations=1):
+ src = source_face
+ src, _ = self.detect_and_align(src)
+ if src is None:
+ print("no face in src_img")
+ return
+ target = target_face
+ align_target, warp_mat = self.detect_and_align(target)
+ if align_target is None:
+ print("no face in target_img")
+ return
+ logger.info("start swapping")
+ frame_size = (target.shape[0], target.shape[1])
+ with torch.no_grad():
+ for _ in range(iterations):
+ swapped_face, m_r = self.model.forward(src, align_target, shape_rate, id_rate)
+ swapped_face = torch.clamp(swapped_face, 0, 1)
+ align_target = swapped_face
+ smooth_face_mask, _ = self.smooth_mask(m_r)
+ warp_mat = torch.from_numpy(warp_mat).float().unsqueeze(0)
+ inverse_warp_mat = inverse_transform_batch(warp_mat, device=self.device)
+ swapped_face, smooth_face_mask = self._geometry_transfrom_warp_affine(
+ swapped_face, inverse_warp_mat, frame_size, smooth_face_mask
+ )
+ target = torch.from_numpy(target.transpose(2, 0, 1)).unsqueeze(0).to(self.device).float() / 255.0
+ result_face = (1 - smooth_face_mask) * target + smooth_face_mask * swapped_face
+ result_face = torch.clamp(result_face * 255.0, 0.0, 255.0, out=None).type(dtype=torch.uint8)
+ result_face = result_face.detach().cpu().numpy()
+ img = result_face.transpose(0, 2, 3, 1)[0]
+
+ return img
+
+
+class ConfigPath:
+ face_detector_weights = "/data/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx"
+ model_path = ""
+ model_idx = 80000
+ device = "cuda"
+
+
+def main():
+ cfg = ConfigPath()
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_path")
+ parser.add_argument("-i", "--model_idx")
+ parser.add_argument("-d", "--device", default="cuda")
+
+ args = parser.parse_args()
+
+ cfg.model_path = args.model_path
+ cfg.model_idx = int(args.model_idx)
+
+ cfg.device = args.device
+ infer = ImageSwap(cfg)
+
+ def inference(source_face, target_face, shape_rate, id_rate):
+ return infer.inference(source_face, target_face, shape_rate, id_rate)
+
+ output = gr.Image(shape=None, label="换脸结果")
+ demo = gr.Interface(
+ fn=inference,
+ inputs=[
+ gr.Image(shape=None, label="选脸图"),
+ gr.Image(shape=None, label="目标图"),
+ gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="3d结构相似度(1.0表示完全替换)",
+ ),
+ gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="人脸特征相似度(1.0表示完全替换)",
+ ),
+ ],
+ outputs=output,
+ title="HiConFace人脸融合系统",
+ description="v1.0: developed by yiwise CV group",
+ )
+ demo.launch(server_name="0.0.0.0", server_port=7860)
+
+ infer.inference()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/app_video.py b/benchmark/app_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c4705599c7602e3688dc2b1dcb2d5cb1dfceca
--- /dev/null
+++ b/benchmark/app_video.py
@@ -0,0 +1,273 @@
+import argparse
+import os
+import uuid
+
+import cv2
+import gradio as gr
+import kornia
+import numpy as np
+import torch
+from loguru import logger
+from torchaudio.io import StreamReader
+from torchaudio.io import StreamWriter
+
+from benchmark.face_pipeline import alignFace
+from benchmark.face_pipeline import FaceDetector
+from benchmark.face_pipeline import inverse_transform_batch
+from benchmark.face_pipeline import SoftErosion
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+class VideoSwap:
+ def __init__(self, cfg, model=None):
+ self.facedetector = FaceDetector(cfg.face_detector_weights)
+ self.alignface = alignFace()
+ self.work_dir = "."
+ opt = TrainConfig()
+ opt.use_ddp = False
+ self.device = "cuda"
+ self.ffmpeg_device = cfg.ffmpeg_device
+ self.num_frames = 10
+ self.kps_window = []
+ checkpoint = (cfg.model_path, cfg.model_idx)
+ if model is None:
+ self.model = HifiFace(
+ opt.identity_extractor_config, is_training=False, device=self.device, load_checkpoint=checkpoint
+ )
+ else:
+ self.model = model
+ self.model.eval()
+ os.makedirs(self.work_dir, exist_ok=True)
+ uid = uuid.uuid4()
+ self.swapped_video = os.path.join(self.work_dir, f"tmp_{uid}.mp4")
+
+ # model-idx_image-name_target-video-name.mp4
+ swapped_with_audio_name = f"result_{uid}.mp4"
+
+ # 带有音频的换脸视频
+ self.swapped_video_with_audio = os.path.join(self.work_dir, swapped_with_audio_name)
+
+ self.smooth_mask = SoftErosion(kernel_size=7, threshold=0.9, iterations=7).to(self.device)
+
+ def yuv_to_rgb(self, img):
+ img = img.to(torch.float)
+ y = img[..., 0, :, :]
+ u = img[..., 1, :, :]
+ v = img[..., 2, :, :]
+ y /= 255
+
+ u = u / 255 - 0.5
+ v = v / 255 - 0.5
+
+ r = y + 1.14 * v
+ g = y + -0.396 * u - 0.581 * v
+ b = y + 2.029 * u
+
+ rgb = torch.stack([r, g, b], -1)
+ return rgb
+
+ def rgb_to_yuv(self, img):
+ r = img[..., 0, :, :]
+ g = img[..., 1, :, :]
+ b = img[..., 2, :, :]
+ y = (0.299 * r + 0.587 * g + 0.114 * b) * 255
+ u = (-0.1471 * r - 0.2889 * g + 0.4360 * b) * 255 + 128
+ v = (0.6149 * r - 0.5149 * g - 0.1 * b) * 255 + 128
+ yuv = torch.stack([y, u, v], -1)
+ return torch.clamp(yuv, 0.0, 255.0, out=None).type(dtype=torch.uint8).transpose(3, 2).transpose(2, 1)
+
+ def _geometry_transfrom_warp_affine(self, swapped_image, inv_att_transforms, frame_size, square_mask):
+ swapped_image = kornia.geometry.transform.warp_affine(
+ swapped_image,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="border",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+
+ square_mask = kornia.geometry.transform.warp_affine(
+ square_mask,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+ return swapped_image, square_mask
+
+ def smooth_kps(self, kps):
+ self.kps_window.append(kps.flatten())
+ self.kps_window = self.kps_window[1:]
+ X = np.stack(self.kps_window, axis=1)
+ y = self.kps_window[-1]
+ y_cor = X @ np.linalg.inv(X.transpose() @ X - 0.0007 * np.eye(self.num_frames)) @ X.transpose() @ y
+ self.kps_window[-1] = y_cor
+ return y_cor.reshape((5, 2))
+
+ def detect_and_align(self, image, src_is=False):
+ detection = self.facedetector(image)
+ if detection.score is None:
+ self.kps_window = []
+ return None, None
+ max_score_ind = np.argmax(detection.score, axis=0)
+ kps = detection.key_points[max_score_ind]
+ if len(self.kps_window) < self.num_frames:
+ self.kps_window.append(kps.flatten())
+ else:
+ kps = self.smooth_kps(kps)
+ align_img, warp_mat = self.alignface.align_face(image, kps, 256)
+ align_img = cv2.resize(align_img, (256, 256))
+ align_img = align_img.transpose(2, 0, 1)
+ align_img = torch.from_numpy(align_img).unsqueeze(0).to(self.device).float()
+ align_img = align_img / 255.0
+ if src_is:
+ self.kps_window = []
+ return align_img, warp_mat
+
+ def inference(self, source_face, target_video, shape_rate, id_rate, iterations=1):
+ video = cv2.VideoCapture(target_video)
+ # 获取视频宽度
+ frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ # 获取视频高度
+ frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ # 获取帧率
+ frame_rate = int(video.get(cv2.CAP_PROP_FPS))
+ video.release()
+ self.frame_size = (frame_height, frame_width)
+ if self.ffmpeg_device == "cuda":
+ self.decode_config = {"frames_per_chunk": 1, "decoder": "h264", "format": "yuv444p"}
+ # self.decode_config = {
+ # "frames_per_chunk": 1,
+ # "decoder": "h264_cuvid",
+ # "decoder_option": {"gpu": "0"},
+ # "hw_accel": "cuda:0",
+ # }
+
+ self.encode_config = {
+ "encoder": "h264_nvenc", # GPU Encoder
+ "encoder_format": "yuv444p",
+ "encoder_option": {"gpu": "0", "cq": "10"}, # Run encoding on the cuda:0 device
+ "hw_accel": "cuda:0", # Data comes from cuda:0 device
+ "frame_rate": frame_rate,
+ "height": frame_height,
+ "width": frame_width,
+ "format": "yuv444p",
+ }
+ else:
+ self.decode_config = {"frames_per_chunk": 1, "decoder": "h264", "format": "yuv444p"}
+
+ self.encode_config = {
+ "encoder": "libx264",
+ "encoder_format": "yuv444p",
+ "frame_rate": frame_rate,
+ "height": frame_height,
+ "width": frame_width,
+ "format": "yuv444p",
+ }
+ src = source_face
+ src, _ = self.detect_and_align(src, src_is=True)
+ logger.info("start swapping")
+ sr = StreamReader(target_video)
+ if self.ffmpeg_device == "cpu":
+ sr.add_basic_video_stream(**self.decode_config)
+ else:
+ sr.add_basic_video_stream(**self.decode_config)
+ # sr.add_video_stream(**self.decode_config)
+ sw = StreamWriter(self.swapped_video)
+ sw.add_video_stream(**self.encode_config)
+ with sw.open():
+ for (chunk,) in sr.stream():
+ # StreamReader cuda decode颜色格式默认为yuv需要转为rgb
+ chunk = self.yuv_to_rgb(chunk)
+ image = (chunk * 255).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+ align_img, warp_mat = self.detect_and_align(image)
+ chunk = chunk.transpose(3, 2).transpose(2, 1).to(self.device)
+ if align_img is None:
+ result_face = chunk
+ else:
+ with torch.no_grad():
+ for _ in range(iterations):
+ swapped_face, m_r = self.model.forward(src, align_img, shape_rate, id_rate)
+ swapped_face = torch.clamp(swapped_face, 0, 1)
+ align_img = swapped_face
+ smooth_face_mask, _ = self.smooth_mask(m_r)
+ warp_mat = torch.from_numpy(warp_mat).float().unsqueeze(0)
+ inverse_warp_mat = inverse_transform_batch(warp_mat)
+ swapped_face, smooth_face_mask = self._geometry_transfrom_warp_affine(
+ swapped_face, inverse_warp_mat, self.frame_size, smooth_face_mask
+ )
+ result_face = (1 - smooth_face_mask) * chunk + smooth_face_mask * swapped_face
+ result_face = self.rgb_to_yuv(result_face).to(self.ffmpeg_device)
+ sw.write_video_chunk(0, result_face)
+
+ # 将target_video中的音频转移到换脸视频上
+ command = f"ffmpeg -loglevel error -i {self.swapped_video} -i {target_video} -c copy \
+ -map 0 -map 1:1? -y -shortest {self.swapped_video_with_audio}"
+ os.system(command)
+
+ # 删除没有音频的换脸视频
+ os.system(f"rm {self.swapped_video}")
+ return self.swapped_video_with_audio
+
+
+class ConfigPath:
+ face_detector_weights = "/mnt/c/yangguo/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx"
+ model_path = ""
+ model_idx = 80000
+ ffmpeg_device = "cuda"
+
+
+def main():
+ cfg = ConfigPath()
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_path")
+ parser.add_argument("-i", "--model_idx")
+ parser.add_argument("-f", "--ffmpeg_device")
+
+ args = parser.parse_args()
+
+ cfg.model_path = args.model_path
+ cfg.model_idx = int(args.model_idx)
+ cfg.ffmpeg_device = args.ffmpeg_device
+
+ infer = VideoSwap(cfg)
+
+ def inference(source_face, target_video, shape_rate, id_rate):
+ return infer.inference(source_face, target_video, shape_rate, id_rate)
+
+ output = gr.Video(value=None, label="换脸结果")
+ demo = gr.Interface(
+ fn=inference,
+ inputs=[
+ gr.Image(shape=None, label="选脸图"),
+ gr.Video(value=None, label="目标视频"),
+ gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="3d结构相似度(1.0表示完全替换)",
+ ),
+ gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ label="人脸特征相似度(1.0表示完全替换)",
+ ),
+ ],
+ outputs=output,
+ title="HiConFace视频人脸融合系统",
+ description="v1.0: developed by yiwise CV group",
+ )
+ demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/face_pipeline.py b/benchmark/face_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03839857bac95fdd91bfc5afadb4419d927c932
--- /dev/null
+++ b/benchmark/face_pipeline.py
@@ -0,0 +1,129 @@
+import time
+from pathlib import Path
+from typing import Iterable
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage import transform as skt
+
+from .scrfd_detect import SCRFD
+
+# ---frontal
+src = np.array(
+ [
+ [39.730, 51.138],
+ [72.270, 51.138],
+ [56.000, 68.493],
+ [42.463, 87.010],
+ [69.537, 87.010],
+ ],
+ dtype=np.float32,
+)
+
+
+class alignFace:
+ def __init__(self) -> None:
+ self.src_map = src
+
+ def estimate_norm(self, lmk, image_size=112):
+ assert lmk.shape == (5, 2)
+ tform = skt.SimilarityTransform()
+ src_ = self.src_map * image_size / 112
+ tform.estimate(lmk, src_)
+ M = tform.params[0:2, :]
+ return M
+
+ def align_face(
+ self, img: np.ndarray, key_points: np.ndarray, crop_size: int
+ ) -> Tuple[Iterable[np.ndarray], Iterable[np.ndarray]]:
+ transform_matrix = self.estimate_norm(key_points, crop_size)
+ align_img = cv2.warpAffine(img, transform_matrix, (crop_size, crop_size), borderValue=0.0)
+ return align_img, transform_matrix
+
+
+class Detection(NamedTuple):
+ bbox: Optional[np.ndarray]
+ score: Optional[np.ndarray]
+ key_points: Optional[np.ndarray]
+
+
+class FaceDetector:
+ def __init__(
+ self,
+ model_path: Path,
+ det_thresh: float = 0.5,
+ det_size: Tuple[int, int] = (640, 640),
+ mode: str = "None",
+ device: str = "cuda",
+ ):
+ self.det_thresh = det_thresh
+ self.mode = mode
+ self.device = device
+ self.handler = SCRFD(str(model_path), device=self.device, det_thresh=det_thresh)
+ ctx_id = -1 if device == "cpu" else 0
+ self.handler.prepare(ctx_id, input_size=det_size)
+
+ def __call__(self, img: np.ndarray, max_num: int = 0) -> Detection:
+ bboxes, kpss = self.handler.detect(img, max_num=max_num, metric="default")
+ if bboxes.shape[0] == 0:
+ return Detection(None, None, None)
+ return Detection(bboxes[..., :-1], bboxes[..., -1], kpss)
+
+
+def tensor2img(tensor):
+ tensor = tensor.detach().cpu().numpy()
+ img = tensor.transpose(0, 2, 3, 1)[0]
+ img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
+ return img
+
+
+def inverse_transform_batch(mat: torch.Tensor, device="cuda") -> torch.Tensor:
+ # inverse the Affine transformation matrix
+ inv_mat = torch.zeros_like(mat).to(device)
+ div1 = mat[:, 0, 0] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 0]
+ inv_mat[:, 0, 0] = mat[:, 1, 1] / div1
+ inv_mat[:, 0, 1] = -mat[:, 0, 1] / div1
+ inv_mat[:, 0, 2] = -(mat[:, 0, 2] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 2]) / div1
+ div2 = mat[:, 0, 1] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 1]
+ inv_mat[:, 1, 0] = mat[:, 1, 0] / div2
+ inv_mat[:, 1, 1] = -mat[:, 0, 0] / div2
+ inv_mat[:, 1, 2] = -(mat[:, 0, 2] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 2]) / div2
+ return inv_mat
+
+
+class SoftErosion(torch.nn.Module):
+ def __init__(self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1):
+ super(SoftErosion, self).__init__()
+ r = kernel_size // 2
+ self.padding = r
+ self.iterations = iterations
+ self.threshold = threshold
+
+ # Create kernel
+ y_indices, x_indices = torch.meshgrid(torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size))
+ dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
+ kernel = dist.max() - dist
+ kernel /= kernel.sum()
+ kernel = kernel.view(1, 1, *kernel.shape)
+ self.register_buffer("weight", kernel)
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ for i in range(self.iterations - 1):
+ x = torch.min(
+ x,
+ F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding),
+ )
+ x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
+
+ mask = x >= self.threshold
+
+ x[mask] = 1.0
+ # add small epsilon to avoid Nans
+ x[~mask] /= x[~mask].max() + 1e-7
+
+ return x, mask
diff --git a/benchmark/inference_image.py b/benchmark/inference_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb2fb34af9d893a5cdeaf0d64820fad9908439da
--- /dev/null
+++ b/benchmark/inference_image.py
@@ -0,0 +1,150 @@
+import argparse
+import os
+
+import cv2
+import kornia
+import numpy as np
+import torch
+from loguru import logger
+
+from benchmark.face_pipeline import alignFace
+from benchmark.face_pipeline import FaceDetector
+from benchmark.face_pipeline import inverse_transform_batch
+from benchmark.face_pipeline import SoftErosion
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+class ImageSwap:
+ def __init__(self, cfg):
+ self.source_face = cfg.source_face
+ self.target_face = cfg.target_face
+ self.device = cfg.device
+ self.facedetector = FaceDetector(cfg.face_detector_weights, device=self.device)
+ self.alignface = alignFace()
+ self.work_dir = cfg.work_dir
+ opt = TrainConfig()
+ opt.use_ddp = False
+ checkpoint = (cfg.model_path, cfg.model_idx)
+ self.model = HifiFace(
+ opt.identity_extractor_config, is_training=False, device=self.device, load_checkpoint=checkpoint
+ )
+ self.model.eval()
+ os.makedirs(self.work_dir, exist_ok=True)
+
+ # model-idx_swapped_src-image-name_target-face-name.jpg
+ swapped_image_name = (
+ str(cfg.model_idx)
+ + "_"
+ + "swapped"
+ + "_"
+ + os.path.basename(self.source_face).split(".")[0]
+ + "_"
+ + os.path.basename(self.target_face).split(".")[0]
+ + ".jpg"
+ )
+ self.swapped_image = os.path.join(self.work_dir, swapped_image_name)
+ self.smooth_mask = SoftErosion(kernel_size=7, threshold=0.9, iterations=7).to(self.device)
+
+ def _geometry_transfrom_warp_affine(self, swapped_image, inv_att_transforms, frame_size, square_mask):
+ swapped_image = kornia.geometry.transform.warp_affine(
+ swapped_image,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="border",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+
+ square_mask = kornia.geometry.transform.warp_affine(
+ square_mask,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+ return swapped_image, square_mask
+
+ def detect_and_align(self, image):
+ detection = self.facedetector(image)
+ if detection.score is None:
+ self.kps_window = []
+ return None, None
+ max_score_ind = np.argmax(detection.score, axis=0)
+ kps = detection.key_points[max_score_ind]
+ align_img, warp_mat = self.alignface.align_face(image, kps, 256)
+ align_img = cv2.resize(align_img, (256, 256))
+ align_img = align_img.transpose(2, 0, 1)
+ align_img = torch.from_numpy(align_img).unsqueeze(0).to(self.device).float()
+ align_img = align_img / 255.0
+ return align_img, warp_mat
+
+ def inference(self):
+ src = cv2.cvtColor(cv2.imread(self.source_face), cv2.COLOR_BGR2RGB)
+ src, _ = self.detect_and_align(src)
+ if src is None:
+ print("no face in src_img")
+ return
+ target = cv2.cvtColor(cv2.imread(self.target_face), cv2.COLOR_BGR2RGB)
+ align_target, warp_mat = self.detect_and_align(target)
+ if align_target is None:
+ print("no face in target_img")
+ return
+ logger.info("start swapping")
+ frame_size = (target.shape[0], target.shape[1])
+ with torch.no_grad():
+ swapped_face, m_r = self.model.forward(src, align_target)
+ swapped_face = torch.clamp(swapped_face, 0, 1)
+ smooth_face_mask, _ = self.smooth_mask(m_r)
+ warp_mat = torch.from_numpy(warp_mat).float().unsqueeze(0)
+ inverse_warp_mat = inverse_transform_batch(warp_mat, device=self.device)
+ swapped_face, smooth_face_mask = self._geometry_transfrom_warp_affine(
+ swapped_face, inverse_warp_mat, frame_size, smooth_face_mask
+ )
+ target = torch.from_numpy(target.transpose(2, 0, 1)).unsqueeze(0).to(self.device).float() / 255.0
+ result_face = (1 - smooth_face_mask) * target + smooth_face_mask * swapped_face
+ result_face = torch.clamp(result_face * 255.0, 0.0, 255.0, out=None).type(dtype=torch.uint8)
+ result_face = result_face.detach().cpu().numpy()
+ img = result_face.transpose(0, 2, 3, 1)[0]
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(self.swapped_image, img)
+
+
+class ConfigPath:
+ source_face = ""
+ target_face = ""
+ work_dir = ""
+ face_detector_weights = "/mnt/c/yangguo/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx"
+ model_path = ""
+ model_idx = 80000
+ device = "cuda"
+
+
+def main():
+ cfg = ConfigPath()
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_path")
+ parser.add_argument("-i", "--model_idx")
+ parser.add_argument("-s", "--source_face")
+ parser.add_argument("-t", "--target_face")
+ parser.add_argument("-w", "--work_dir")
+ parser.add_argument("-d", "--device", default="cuda")
+
+ args = parser.parse_args()
+ cfg.source_face = args.source_face
+ cfg.target_face = args.target_face
+ cfg.model_path = args.model_path
+ cfg.model_idx = int(args.model_idx)
+ cfg.work_dir = args.work_dir
+ cfg.device = args.device
+ infer = ImageSwap(cfg)
+ infer.inference()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/inference_video.py b/benchmark/inference_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..73755e9bee0e6030e2568bcaeb0b63c3eca354af
--- /dev/null
+++ b/benchmark/inference_video.py
@@ -0,0 +1,250 @@
+import argparse
+import os
+
+import cv2
+import kornia
+import numpy as np
+import torch
+from loguru import logger
+from torchaudio.io import StreamReader
+from torchaudio.io import StreamWriter
+
+from benchmark.face_pipeline import alignFace
+from benchmark.face_pipeline import FaceDetector
+from benchmark.face_pipeline import inverse_transform_batch
+from benchmark.face_pipeline import SoftErosion
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+class VideoSwap:
+ def __init__(self, cfg):
+ self.source_face = cfg.source_face
+ self.target_video = cfg.target_video
+ self.facedetector = FaceDetector(cfg.face_detector_weights)
+ self.alignface = alignFace()
+ self.work_dir = cfg.work_dir
+ opt = TrainConfig()
+ opt.use_ddp = False
+ self.device = "cuda"
+ self.ffmpeg_device = cfg.ffmpeg_device
+ self.num_frames = 10
+ self.kps_window = []
+ checkpoint = (cfg.model_path, cfg.model_idx)
+ self.model = HifiFace(
+ opt.identity_extractor_config, is_training=False, device=self.device, load_checkpoint=checkpoint
+ )
+ self.model.eval()
+ os.makedirs(self.work_dir, exist_ok=True)
+ self.swapped_video = os.path.join(self.work_dir, "swapped_video.mp4")
+
+ # model-idx_image-name_target-video-name.mp4
+ swapped_with_audio_name = (
+ str(cfg.model_idx)
+ + "_"
+ + os.path.basename(self.source_face).split(".")[0]
+ + "_"
+ + os.path.basename(self.target_video).split(".")[0]
+ + ".mp4"
+ )
+ # 带有音频的换脸视频
+ self.swapped_video_with_audio = os.path.join(self.work_dir, swapped_with_audio_name)
+
+ video = cv2.VideoCapture(self.target_video)
+ # 获取视频宽度
+ frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ # 获取视频高度
+ frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ # 获取帧率
+ frame_rate = int(video.get(cv2.CAP_PROP_FPS))
+ video.release()
+ self.frame_size = (frame_height, frame_width)
+
+ if self.ffmpeg_device == "cuda":
+ self.decode_config = {
+ "frames_per_chunk": 1,
+ "decoder": "h264_cuvid",
+ "decoder_option": {"gpu": "0"},
+ "hw_accel": "cuda:0",
+ }
+
+ self.encode_config = {
+ "encoder": "h264_nvenc", # GPU Encoder
+ "encoder_format": "yuv444p",
+ "encoder_option": {"gpu": "0"}, # Run encoding on the cuda:0 device
+ "hw_accel": "cuda:0", # Data comes from cuda:0 device
+ "frame_rate": frame_rate,
+ "height": frame_height,
+ "width": frame_width,
+ "format": "yuv444p",
+ }
+ else:
+ self.decode_config = {"frames_per_chunk": 1, "decoder": "h264", "format": "yuv444p"}
+
+ self.encode_config = {
+ "encoder": "libx264",
+ "encoder_format": "yuv444p",
+ "frame_rate": frame_rate,
+ "height": frame_height,
+ "width": frame_width,
+ "format": "yuv444p",
+ }
+
+ self.smooth_mask = SoftErosion(kernel_size=7, threshold=0.9, iterations=7).to(self.device)
+
+ def yuv_to_rgb(self, img):
+ img = img.to(torch.float)
+ y = img[..., 0, :, :]
+ u = img[..., 1, :, :]
+ v = img[..., 2, :, :]
+ y /= 255
+
+ u = u / 255 - 0.5
+ v = v / 255 - 0.5
+
+ r = y + 1.14 * v
+ g = y + -0.396 * u - 0.581 * v
+ b = y + 2.029 * u
+
+ rgb = torch.stack([r, g, b], -1)
+ return rgb
+
+ def rgb_to_yuv(self, img):
+ r = img[..., 0, :, :]
+ g = img[..., 1, :, :]
+ b = img[..., 2, :, :]
+ y = (0.299 * r + 0.587 * g + 0.114 * b) * 255
+ u = (-0.1471 * r - 0.2889 * g + 0.4360 * b) * 255 + 128
+ v = (0.6149 * r - 0.5149 * g - 0.1 * b) * 255 + 128
+ yuv = torch.stack([y, u, v], -1)
+ return torch.clamp(yuv, 0.0, 255.0, out=None).type(dtype=torch.uint8).transpose(3, 2).transpose(2, 1)
+
+ def _geometry_transfrom_warp_affine(self, swapped_image, inv_att_transforms, frame_size, square_mask):
+ swapped_image = kornia.geometry.transform.warp_affine(
+ swapped_image,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="border",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+
+ square_mask = kornia.geometry.transform.warp_affine(
+ square_mask,
+ inv_att_transforms,
+ frame_size,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=True,
+ fill_value=torch.zeros(3),
+ )
+ return swapped_image, square_mask
+
+ def smooth_kps(self, kps):
+ self.kps_window.append(kps.flatten())
+ self.kps_window = self.kps_window[1:]
+ X = np.stack(self.kps_window, axis=1)
+ y = self.kps_window[-1]
+ y_cor = X @ np.linalg.inv(X.transpose() @ X - 0.0007 * np.eye(self.num_frames)) @ X.transpose() @ y
+ self.kps_window[-1] = y_cor
+ return y_cor.reshape((5, 2))
+
+ def detect_and_align(self, image, src_is=False):
+ detection = self.facedetector(image)
+ if detection.score is None:
+ self.kps_window = []
+ return None, None
+ max_score_ind = np.argmax(detection.score, axis=0)
+ kps = detection.key_points[max_score_ind]
+ if len(self.kps_window) < self.num_frames:
+ self.kps_window.append(kps.flatten())
+ else:
+ kps = self.smooth_kps(kps)
+ align_img, warp_mat = self.alignface.align_face(image, kps, 256)
+ align_img = cv2.resize(align_img, (256, 256))
+ align_img = align_img.transpose(2, 0, 1)
+ align_img = torch.from_numpy(align_img).unsqueeze(0).to(self.device).float()
+ align_img = align_img / 255.0
+ if src_is:
+ self.kps_window = []
+ return align_img, warp_mat
+
+ def inference(self):
+ src = cv2.cvtColor(cv2.imread(self.source_face), cv2.COLOR_BGR2RGB)
+ src, _ = self.detect_and_align(src, src_is=True)
+ logger.info("start swapping")
+ sr = StreamReader(self.target_video)
+ if self.ffmpeg_device == "cpu":
+ sr.add_basic_video_stream(**self.decode_config)
+ else:
+ sr.add_video_stream(**self.decode_config)
+ sw = StreamWriter(self.swapped_video)
+ sw.add_video_stream(**self.encode_config)
+ with sw.open():
+ for (chunk,) in sr.stream():
+ # StreamReader cuda decode颜色格式默认为yuv需要转为rgb
+ chunk = self.yuv_to_rgb(chunk)
+ image = (chunk * 255).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+ chunk = chunk.transpose(3, 2).transpose(2, 1).to(self.device)
+ align_img, warp_mat = self.detect_and_align(image)
+ if align_img is None:
+ result_face = chunk
+ else:
+ with torch.no_grad():
+ swapped_face, m_r = self.model.forward(src, align_img)
+ swapped_face = torch.clamp(swapped_face, 0, 1)
+ smooth_face_mask, _ = self.smooth_mask(m_r)
+ warp_mat = torch.from_numpy(warp_mat).float().unsqueeze(0)
+ inverse_warp_mat = inverse_transform_batch(warp_mat)
+ swapped_face, smooth_face_mask = self._geometry_transfrom_warp_affine(
+ swapped_face, inverse_warp_mat, self.frame_size, smooth_face_mask
+ )
+ result_face = (1 - smooth_face_mask) * chunk + smooth_face_mask * swapped_face
+ result_face = self.rgb_to_yuv(result_face)
+ sw.write_video_chunk(0, result_face.to(self.ffmpeg_device))
+
+ # 将target_video中的音频转移到换脸视频上
+ command = f"ffmpeg -loglevel error -i {self.swapped_video} -i {self.target_video} -c copy \
+ -map 0 -map 1:1? -y -shortest {self.swapped_video_with_audio}"
+ os.system(command)
+
+ # 删除没有音频的换脸视频
+ os.system(f"rm {self.swapped_video}")
+
+
+class ConfigPath:
+ source_face = ""
+ target_video = ""
+ work_dir = ""
+ face_detector_weights = "/mnt/c/yangguo/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx"
+ model_path = ""
+ model_idx = 80000
+ ffmpeg_device = "cuda"
+
+
+def main():
+ cfg = ConfigPath()
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_path")
+ parser.add_argument("-i", "--model_idx")
+ parser.add_argument("-s", "--source_face")
+ parser.add_argument("-t", "--target_video")
+ parser.add_argument("-w", "--work_dir")
+ parser.add_argument("-f", "--ffmpeg_device")
+
+ args = parser.parse_args()
+ cfg.source_face = args.source_face
+ cfg.target_video = args.target_video
+ cfg.model_path = args.model_path
+ cfg.model_idx = int(args.model_idx)
+ cfg.work_dir = args.work_dir
+ cfg.ffmpeg_device = args.ffmpeg_device
+ infer = VideoSwap(cfg)
+ infer.inference()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/scrfd_detect.py b/benchmark/scrfd_detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..444cb31fd2f6b6effc86d9a0e8fe10b5d26439ca
--- /dev/null
+++ b/benchmark/scrfd_detect.py
@@ -0,0 +1,363 @@
+# -*- coding: utf-8 -*-
+"""
+@File : scrfd
+@Description: scrfd人脸检测
+@Author: Yang Jian
+@Contact: lian01110@outlook.com
+@Time: 2022/2/25 10:31
+@IDE: PYTHON
+@REFERENCE: https://github.com/yangjian1218
+"""
+from __future__ import division
+
+import datetime
+import os
+import os.path as osp
+import sys
+
+import cv2
+import numpy as np
+import onnx
+import onnxruntime
+from cv2 import KeyPoint
+
+# import face_align
+
+
+def softmax(z):
+ assert len(z.shape) == 2
+ s = np.max(z, axis=1)
+ s = s[:, np.newaxis] # necessary step to do broadcasting
+ e_x = np.exp(z - s)
+ div = np.sum(e_x, axis=1)
+ div = div[:, np.newaxis] # dito
+ return e_x / div
+
+
+def distance2bbox(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom).
+ max_shape (tuple): Shape of the image.
+
+ Returns:
+ Tensor: Decoded bboxes.
+ """
+ x1 = points[:, 0] - distance[:, 0]
+ y1 = points[:, 1] - distance[:, 1]
+ x2 = points[:, 0] + distance[:, 2]
+ y2 = points[:, 1] + distance[:, 3]
+ if max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1])
+ y1 = y1.clamp(min=0, max=max_shape[0])
+ x2 = x2.clamp(min=0, max=max_shape[1])
+ y2 = y2.clamp(min=0, max=max_shape[0])
+ return np.stack([x1, y1, x2, y2], axis=-1)
+
+
+def distance2kps(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom).
+ max_shape (tuple): Shape of the image.
+
+ Returns:
+ Tensor: Decoded bboxes.
+ """
+ preds = []
+ for i in range(0, distance.shape[1], 2):
+ px = points[:, i % 2] + distance[:, i]
+ py = points[:, i % 2 + 1] + distance[:, i + 1]
+ if max_shape is not None:
+ px = px.clamp(min=0, max=max_shape[1])
+ py = py.clamp(min=0, max=max_shape[0])
+ preds.append(px)
+ preds.append(py)
+ return np.stack(preds, axis=-1)
+
+
+class SCRFD:
+ def __init__(self, model_file=None, session=None, device="cuda", det_thresh=0.5):
+ self.model_file = model_file
+ self.session = session
+ self.taskname = "detection"
+ if self.session is None:
+ assert self.model_file is not None
+ assert osp.exists(self.model_file)
+ if device == "cpu":
+ providers = ["CPUExecutionProvider"]
+ else:
+ providers = ["CUDAExecutionProvider"]
+ self.session = onnxruntime.InferenceSession(self.model_file, providers=providers)
+ self.center_cache = {}
+ self.nms_thresh = 0.4
+ self.det_thresh = det_thresh
+ self._init_vars()
+
+ def _init_vars(self):
+ input_cfg = self.session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ # print("input_shape:",input_shape)
+ if isinstance(input_shape[2], str):
+ self.input_size = None
+ else:
+ self.input_size = tuple(input_shape[2:4][::-1])
+ # print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ self.input_shape = input_shape
+ outputs = self.session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ self.input_name = input_name
+ self.output_names = output_names
+ # print("input_name:",self.input_name)
+ # print("output_name:",self.output_names)
+ self.input_mean = 127.5
+ self.input_std = 127.5
+ # assert len(outputs)==10 or len(outputs)==15
+ self.use_kps = False
+ self._anchor_ratio = 1.0
+ self._num_anchors = 1
+
+ if len(outputs) == 6:
+ self.fmc = 3
+ self._feat_stride_fpn = [8, 16, 32]
+ self._num_anchors = 2
+ elif len(outputs) == 9:
+ self.fmc = 3
+ self._feat_stride_fpn = [8, 16, 32]
+ self._num_anchors = 2
+ self.use_kps = True
+ elif len(outputs) == 10:
+ self.fmc = 5
+ self._feat_stride_fpn = [8, 16, 32, 64, 128]
+ self._num_anchors = 1
+ elif len(outputs) == 15:
+ self.fmc = 5
+ self._feat_stride_fpn = [8, 16, 32, 64, 128]
+ self._num_anchors = 1
+ self.use_kps = True
+
+ def init_det_threshold(self, det_threshold):
+ """
+ 单独设置人脸检测阈值
+ :param det_threshold: 人脸检测阈值
+ :return:
+ """
+ self.det_thresh = det_threshold
+
+ def prepare(self, ctx_id, **kwargs):
+ if ctx_id < 0:
+ self.session.set_providers(["CPUExecutionProvider"])
+ nms_threshold = kwargs.get("nms_threshold", None)
+ if nms_threshold is not None:
+ self.nms_threshold = nms_threshold
+ input_size = kwargs.get("input_size", None)
+ if input_size is not None:
+ if self.input_size is not None:
+ print("warning: det_size is already set in scrfd model, ignore")
+ else:
+ self.input_size = input_size
+
+ def forward(self, img, threshold=0.6, swap_rb=True):
+ scores_list = []
+ bboxes_list = []
+ kpss_list = []
+ input_size = tuple(img.shape[0:2][::-1])
+ # print('input_size:',input_size)
+ blob = cv2.dnn.blobFromImages(
+ [img], 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=swap_rb
+ )
+ net_outs = self.session.run(self.output_names, {self.input_name: blob})
+ # print("net_outs:::",net_outs[0])
+ input_height = blob.shape[2]
+ input_width = blob.shape[3]
+ fmc = self.fmc # 3
+ for idx, stride in enumerate(self._feat_stride_fpn):
+ scores = net_outs[idx]
+ # print("scores:",scores)
+ bbox_preds = net_outs[idx + fmc]
+ bbox_preds = bbox_preds * stride
+ if self.use_kps:
+ kps_preds = net_outs[idx + fmc * 2] * stride
+ height = input_height // stride
+ width = input_width // stride
+ K = height * width
+ key = (height, width, stride)
+ if key in self.center_cache:
+ anchor_centers = self.center_cache[key]
+ else:
+ # solution-1, c style:
+ # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
+ # for i in range(height):
+ # anchor_centers[i, :, 1] = i
+ # for i in range(width):
+ # anchor_centers[:, i, 0] = i
+
+ # solution-2:
+ # ax = np.arange(width, dtype=np.float32)
+ # ay = np.arange(height, dtype=np.float32)
+ # xv, yv = np.meshgrid(np.arange(width), np.arange(height))
+ # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
+
+ # solution-3:
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
+ # print(anchor_centers.shape)
+
+ anchor_centers = (anchor_centers * stride).reshape((-1, 2))
+ if self._num_anchors > 1:
+ anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
+ if len(self.center_cache) < 100:
+ self.center_cache[key] = anchor_centers
+ # print(anchor_centers.shape,bbox_preds.shape,scores.shape,kps_preds.shape)
+ pos_inds = np.where(scores >= threshold)[0]
+ # print("pos_inds:",pos_inds)
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
+ pos_scores = scores[pos_inds]
+ pos_bboxes = bboxes[pos_inds]
+ scores_list.append(pos_scores)
+ bboxes_list.append(pos_bboxes)
+ if self.use_kps:
+ kpss = distance2kps(anchor_centers, kps_preds)
+ # kpss = kps_preds
+ kpss = kpss.reshape((kpss.shape[0], -1, 2))
+ pos_kpss = kpss[pos_inds]
+ kpss_list.append(pos_kpss)
+ # print("....:",bboxes_list)
+ return scores_list, bboxes_list, kpss_list
+
+ def detect(self, img, input_size=None, max_num=0, det_thresh=None, metric="default", swap_rb=True):
+ """
+
+ :param img: 原始图像
+ :param input_size: 输入尺寸,元组或者列表
+ :param max_num: 返回人脸数量, 如果为0,表示所有,
+ :param det_thresh: 人脸检测阈值,
+ :param metric: 排序方式,默认为面积+中心偏移, "max"为面积最大排序
+ :param swap_rb: 是否进行r b通道转换, 如果传入的是bgr格式图片,则需要为True
+ :return:
+ """
+ assert input_size is not None or self.input_size is not None
+ input_size = self.input_size if input_size is None else input_size
+ # resize方法选择,缩小选择cv2.INTER_AREA , 放大选择cv2.INTER_LINEAR
+ resize_interpolation = cv2.INTER_AREA if img.shape[0] >= input_size[0] else cv2.INTER_LINEAR
+ im_ratio = float(img.shape[0]) / img.shape[1]
+ model_ratio = float(input_size[1]) / input_size[0]
+ if im_ratio > model_ratio:
+ new_height = input_size[1]
+ new_width = int(new_height / im_ratio)
+ else:
+ new_width = input_size[0]
+ new_height = int(new_width * im_ratio)
+ det_scale = float(new_height) / img.shape[0]
+ resized_img = cv2.resize(img, (new_width, new_height), interpolation=resize_interpolation)
+ det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
+ det_img[:new_height, :new_width, :] = resized_img
+ if det_thresh == None:
+ det_thresh = self.det_thresh
+ scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh, swap_rb)
+ # print("====",len(scores_list),len(bboxes_list),len(kpss_list))
+ # print("scores_list:",scores_list)
+ scores = np.vstack(scores_list)
+ scores_ravel = scores.ravel()
+ order = scores_ravel.argsort()[::-1]
+ bboxes = np.vstack(bboxes_list) / det_scale
+ if self.use_kps:
+ kpss = np.vstack(kpss_list) / det_scale
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
+ pre_det = pre_det[order, :]
+ keep = self.nms(pre_det)
+ det = pre_det[keep, :]
+ if self.use_kps:
+ kpss = kpss[order, :, :]
+ kpss = kpss[keep, :, :]
+ else:
+ kpss = None
+ if max_num > 0 and det.shape[0] > max_num:
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
+ img_center = img.shape[0] // 2, img.shape[1] // 2
+ offsets = np.vstack(
+ [(det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 1] + det[:, 3]) / 2 - img_center[0]]
+ )
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
+ if metric == "max":
+ values = area
+ else:
+ values = area - offset_dist_squared * 2.0 # some extra weight on the centering
+ bindex = np.argsort(values)[::-1] # some extra weight on the centering
+ bindex = bindex[0:max_num]
+ det = det[bindex, :]
+ if kpss is not None:
+ kpss = kpss[bindex, :]
+ return det, kpss
+
+ def nms(self, dets):
+ thresh = self.nms_thresh
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
+
+
+if __name__ == "__main__":
+
+ detector = SCRFD(
+ model_file="/mnt/c/yangguo/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx", device="cpu"
+ )
+ # detector.prepare()
+ img_path = "/mnt/c/yangguo/hififace_infer/src_image/boy.jpg"
+ img = cv2.imread(img_path)
+ ta = datetime.datetime.now()
+ cycle = 100
+ # for i in range(cycle):
+ bboxes, kpss = detector.detect(img, input_size=(640, 640)) # 得到box跟关键点
+ # print("bboxes:",bboxes,"\nkpss:",kpss)
+ tb = datetime.datetime.now()
+ print("all cost:", (tb - ta).total_seconds() * 1000)
+ print(img_path, bboxes.shape)
+ if kpss is not None:
+ print(kpss.shape)
+ # todo 画图
+ for i in range(bboxes.shape[0]):
+ bbox = bboxes[i]
+ x1, y1, x2, y2, score = bbox.astype(np.int32)
+ cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+ if kpss is not None:
+ kps = kpss[i]
+ for kp in kps:
+ kp = kp.astype(np.int32)
+ cv2.circle(img, tuple(kp), 1, (0, 0, 255), 2)
+ # cv2.namedWindow("img", 2)
+ cv2.imwrite("./img.jpg", img)
+ # cv2.imshow("img", img)
+ # cv2.waitKey(0)
diff --git a/benchmark/test.py b/benchmark/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..80efe31587b47a0af8d991901d2675d0e209d128
--- /dev/null
+++ b/benchmark/test.py
@@ -0,0 +1,220 @@
+import argparse
+import os
+from typing import List
+from typing import Optional
+
+import cv2
+import numpy as np
+import torch
+
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+def test(
+ data_root: str,
+ result_path: str,
+ source_face: List[str],
+ target_face: List[str],
+ model_path: str,
+ model_idx: Optional[int],
+):
+ opt = TrainConfig()
+ opt.use_ddp = False
+
+ device = "cpu"
+ checkpoint = (model_path, model_idx)
+ model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
+ model.eval()
+
+ results = []
+ for source, target in zip(source_face, target_face):
+ source = os.path.join(data_root, source)
+ target = os.path.join(data_root, target)
+
+ src_img = cv2.imread(source)
+ src_img = cv2.resize(src_img, (256, 256))
+ src = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+ src = src.transpose(2, 0, 1)
+ src = torch.from_numpy(src).unsqueeze(0).to(device).float()
+ src = src / 255.0
+
+ tgt_img = cv2.imread(target)
+ tgt_img = cv2.resize(tgt_img, (256, 256))
+ tgt = cv2.cvtColor(tgt_img, cv2.COLOR_BGR2RGB)
+ tgt = tgt.transpose(2, 0, 1)
+ tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
+ tgt = tgt / 255.0
+
+ with torch.no_grad():
+ result_face = model.forward(src, tgt).cpu()
+ result_face = torch.clamp(result_face, 0, 1) * 255
+ result_face = result_face.numpy()[0].astype(np.uint8)
+ result_face = result_face.transpose(1, 2, 0)
+
+ result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB)
+ one_result = np.concatenate((src_img, tgt_img, result_face), axis=0)
+ results.append(one_result)
+ result = np.concatenate(results, axis=1)
+ swapped_face = os.path.join(data_root, result_path)
+ cv2.imwrite(swapped_face, result)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_name")
+ parser.add_argument("-i", "--model_index")
+ args = parser.parse_args()
+ data_root = "/home/xuehongyang/data/face_swap_test"
+
+ model_path = os.path.join("/data/checkpoints/hififace/", args.model_name)
+ model_idx = int(args.model_index)
+
+ name = f"{args.model_name}_{args.model_index}"
+ source = [
+ "male_1.jpg",
+ "male_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "test1.jpg",
+ "test1.jpg",
+ "test1.jpg",
+ ]
+ target = [
+ "male_2.jpg",
+ "male_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "male_2.jpg",
+ "male_1.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "male_1.jpg",
+ ]
+
+ target_src = os.path.join(data_root, f"../{name}_1.jpg")
+ test(data_root, target_src, source, target, model_path, model_idx)
+
+ source = [
+ "male_2.jpg",
+ "male_1.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "male_1.jpg",
+ "male_2.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ ]
+ target = [
+ "male_1.jpg",
+ "male_2.jpg",
+ "minlu_1.jpg",
+ "minlu_2.jpg",
+ "shizong_1.jpg",
+ "shizong_2.jpg",
+ "tianxin_1.jpg",
+ "tianxin_2.jpg",
+ "xiaohui_1.jpg",
+ "xiaohui_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_3.jpg",
+ "female_4.jpg",
+ "female_5.jpg",
+ "female_6.jpg",
+ "lixia_1.jpg",
+ "lixia_2.jpg",
+ "qq_1.jpg",
+ "qq_2.jpg",
+ "pink_1.jpg",
+ "pink_2.jpg",
+ "xulie_1.jpg",
+ "xulie_2.jpg",
+ ]
+
+ target_src = os.path.join(data_root, f"../{name}_2.jpg")
+ test(data_root, target_src, source, target, model_path, model_idx)
+
+ source = [
+ "male_2.jpg",
+ "male_1.jpg",
+ "shizong_1.jpg",
+ "shizong_2.jpg",
+ "minlu_1.jpg",
+ "minlu_2.jpg",
+ "xiaohui_1.jpg",
+ "xiaohui_2.jpg",
+ "tianxin_1.jpg",
+ "tianxin_2.jpg",
+ "female_2.jpg",
+ "female_1.jpg",
+ "female_5.jpg",
+ "female_6.jpg",
+ "female_3.jpg",
+ "female_4.jpg",
+ "qq_1.jpg",
+ "qq_2.jpg",
+ "pink_1.jpg",
+ "pink_2.jpg",
+ "xulie_1.jpg",
+ "xulie_2.jpg",
+ "lixia_1.jpg",
+ "lixia_2.jpg",
+ ]
+ target = [
+ "male_2.jpg",
+ "male_1.jpg",
+ "minlu_1.jpg",
+ "minlu_2.jpg",
+ "shizong_1.jpg",
+ "shizong_2.jpg",
+ "tianxin_1.jpg",
+ "tianxin_2.jpg",
+ "xiaohui_1.jpg",
+ "xiaohui_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_3.jpg",
+ "female_4.jpg",
+ "female_5.jpg",
+ "female_6.jpg",
+ "lixia_1.jpg",
+ "lixia_2.jpg",
+ "qq_1.jpg",
+ "qq_2.jpg",
+ "pink_1.jpg",
+ "pink_2.jpg",
+ "xulie_1.jpg",
+ "xulie_2.jpg",
+ ]
+
+ target_src = os.path.join(data_root, f"../{name}_3.jpg")
+ test(data_root, target_src, source, target, model_path, model_idx)
diff --git a/benchmark/test_1tom.py b/benchmark/test_1tom.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db34ee54133ceaeacfa5378899853530fd6d95c
--- /dev/null
+++ b/benchmark/test_1tom.py
@@ -0,0 +1,107 @@
+import argparse
+import os
+from typing import List
+from typing import Optional
+
+import cv2
+import numpy as np
+import torch
+
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+def test(
+ data_root: str,
+ result_path: str,
+ source_face: List[str],
+ target_face: List[str],
+ model_path: str,
+ model_idx: Optional[int],
+):
+ opt = TrainConfig()
+ opt.use_ddp = False
+
+ device = "cpu"
+ checkpoint = (model_path, model_idx)
+ model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
+ model.eval()
+
+ results = []
+ for source, target in zip(source_face, target_face):
+ source = os.path.join(data_root, source)
+ target = os.path.join(data_root, target)
+
+ src_img = cv2.imread(source)
+ src_img = cv2.resize(src_img, (256, 256))
+ src = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+ src = src.transpose(2, 0, 1)
+ src = torch.from_numpy(src).unsqueeze(0).to(device).float()
+ src = src / 255.0
+
+ tgt_img = cv2.imread(target)
+ tgt_img = cv2.resize(tgt_img, (256, 256))
+ tgt = cv2.cvtColor(tgt_img, cv2.COLOR_BGR2RGB)
+ tgt = tgt.transpose(2, 0, 1)
+ tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
+ tgt = tgt / 255.0
+
+ with torch.no_grad():
+ result_face = model.forward(src, tgt).cpu()
+ result_face = torch.clamp(result_face, 0, 1) * 255
+ result_face = result_face.numpy()[0].astype(np.uint8)
+ result_face = result_face.transpose(1, 2, 0)
+
+ result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB)
+ one_result = np.concatenate((src_img, tgt_img, result_face), axis=0)
+ results.append(one_result)
+ result = np.concatenate(results, axis=1)
+ swapped_face = os.path.join(data_root, result_path)
+ cv2.imwrite(swapped_face, result)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
+ )
+ parser.add_argument("-m", "--model_name")
+ parser.add_argument("-i", "--model_index")
+ parser.add_argument("-s", "--source_image")
+ args = parser.parse_args()
+ data_root = "/home/xuehongyang/data/face_swap_test"
+
+ model_path = os.path.join("/data/checkpoints/hififace/", args.model_name)
+ model_idx = int(args.model_index)
+
+ name = f"{args.model_name}_{args.model_index}"
+
+ target = [
+ "male_1.jpg",
+ "male_2.jpg",
+ "minlu_1.jpg",
+ "minlu_2.jpg",
+ "shizong_1.jpg",
+ "shizong_2.jpg",
+ "tianxin_1.jpg",
+ "tianxin_2.jpg",
+ "xiaohui_1.jpg",
+ "xiaohui_2.jpg",
+ "female_1.jpg",
+ "female_2.jpg",
+ "female_3.jpg",
+ "female_4.jpg",
+ "female_5.jpg",
+ "female_6.jpg",
+ "lixia_1.jpg",
+ "lixia_2.jpg",
+ "qq_1.jpg",
+ "qq_2.jpg",
+ "pink_1.jpg",
+ "pink_2.jpg",
+ "xulie_1.jpg",
+ "xulie_2.jpg",
+ ]
+
+ source = [args.source_image] * len(target)
+ target_src = os.path.join(data_root, f"../{name}_1tom_{args.source_image}.jpg")
+ test(data_root, target_src, source, target, model_path, model_idx)
diff --git a/configs/__pycache__/mode.cpython-310.pyc b/configs/__pycache__/mode.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8caeffb1f6ce8a500eb300fa4df50b68c8f93de7
Binary files /dev/null and b/configs/__pycache__/mode.cpython-310.pyc differ
diff --git a/configs/__pycache__/singleton.cpython-310.pyc b/configs/__pycache__/singleton.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72fcd2f313a7752e155ea7989acb1cd3948d9a4a
Binary files /dev/null and b/configs/__pycache__/singleton.cpython-310.pyc differ
diff --git a/configs/__pycache__/train_config.cpython-310.pyc b/configs/__pycache__/train_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..830bdf37b24c23bc4db96dd04cc1a5e8df27b8ce
Binary files /dev/null and b/configs/__pycache__/train_config.cpython-310.pyc differ
diff --git a/configs/mode.py b/configs/mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..64403d31cf15bf14473b054d860344c028609ed9
--- /dev/null
+++ b/configs/mode.py
@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class FaceSwapMode(Enum):
+ MANY_TO_MANY = 1
+ ONE_TO_MANY = 2
diff --git a/configs/singleton.py b/configs/singleton.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7a74fe00b0d31c4df8029ff2c0d3893c6d7c01f
--- /dev/null
+++ b/configs/singleton.py
@@ -0,0 +1,16 @@
+import functools
+
+
+def Singleton(cls):
+ """
+ 单例decorator
+ """
+ _instance = {}
+
+ @functools.wraps(cls)
+ def _singleton(*args, **kargs):
+ if cls not in _instance:
+ _instance[cls] = cls(*args, **kargs)
+ return _instance[cls]
+
+ return _singleton
diff --git a/configs/train_config.py b/configs/train_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..80048a9cbe0e0920d74be47ef5d57f66ef569b5e
--- /dev/null
+++ b/configs/train_config.py
@@ -0,0 +1,56 @@
+import os
+import time
+from dataclasses import dataclass
+
+from configs.mode import FaceSwapMode
+from configs.singleton import Singleton
+
+
+@Singleton
+@dataclass
+class TrainConfig:
+ mode = FaceSwapMode.MANY_TO_MANY
+ source_name: str = ""
+
+ dataset_index: str = "/data/dataset/faceswap/full.pkl"
+ dataset_root: str = "/data/dataset/faceswap"
+
+ batch_size: int = 8
+ num_threads: int = 8
+ same_rate: float = 0.5
+ lr: float = 5e-5
+ grad_clip: float = 1000.0
+
+ use_ddp: bool = True
+
+ mouth_mask: bool = True
+ eye_hm_loss: bool = False
+ mouth_hm_loss: bool = False
+
+ load_checkpoint = None # ("/data/checkpoints/hififace/rebuilt_discriminator_SFF_c256_1683367464544", 400000)
+
+ identity_extractor_config = {
+ "f_3d_checkpoint_path": "./checkpoints/Deep3DFaceRecon/epoch_20_new.pth",
+ "f_id_checkpoint_path": "./checkpoints/arcface/ms1mv3_arcface_r100_fp16_backbone.pth",
+ "bfm_folder": "./checkpoints/useful_ckpt/BFM",
+ "hrnet_path": "./checkpoints/useful_ckpt/face_98lmks/HR18-WFLW.pth",
+ }
+
+ visualize_interval: int = 100
+ plot_interval: int = 100
+ max_iters: int = 1000000
+ checkpoint_interval: int = 40000
+
+ exp_name: str = "exp_base"
+ log_basedir: str = "/data/logs/hififace/"
+ checkpoint_basedir = "/data/checkpoints/hififace"
+
+ def __post_init__(self):
+ time_stamp = int(time.time() * 1000)
+ self.log_dir = os.path.join(self.log_basedir, f"{self.exp_name}_{time_stamp}")
+ self.checkpoint_dir = os.path.join(self.checkpoint_basedir, f"{self.exp_name}_{time_stamp}")
+
+
+if __name__ == "__main__":
+ tc = TrainConfig()
+ print(tc.log_dir)
diff --git a/data/__pycache__/dataset.cpython-310.pyc b/data/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6809427f8e046c0356d446001e38d697b8d7c97c
Binary files /dev/null and b/data/__pycache__/dataset.cpython-310.pyc differ
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..97333d590c7ee7d2b40ec66beeccf3f52560b58d
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,227 @@
+import os
+import pickle
+import random
+from pathlib import Path
+from typing import Dict
+from typing import List
+
+import torch
+from loguru import logger
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from configs.mode import FaceSwapMode
+from configs.train_config import TrainConfig
+
+
+class ManyToManyTrainDataset(Dataset):
+ def __init__(self, dataset_root: str, dataset_index: str, same_rate=0.5):
+ """
+ Many-to-many 训练数据集构建
+ Parameters:
+ -----------
+ dataset_root: str, 数据集根目录
+ dataset_index: str, 数据集index文件路径
+ same_rate: float, 每个batch里面相同人脸所占的比例
+ """
+ super(ManyToManyTrainDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((256, 256)),
+ transforms.ToTensor(),
+ ]
+ )
+ self.data_root = Path(dataset_root)
+ with open(dataset_index, "rb") as f:
+ self.file_index = pickle.load(f, encoding="bytes")
+
+ self.same_rate = same_rate
+
+ self.id_list: List[str] = list(self.file_index.keys())
+
+ # 所有id都遍历一遍,视为一个epoch
+ self.length = len(self.id_list)
+ self.image_num = sum([len(v) for v in self.file_index.values()])
+
+ self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
+ logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
+ logger.info(f"will use mask mode: {self.mask_dir}")
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index):
+ source_id_index = index
+ source_file = random.choice(self.file_index[self.id_list[source_id_index]])
+ if random.random() < self.same_rate:
+ # 在相同id的文件列表中选择
+ target_file = random.choice(self.file_index[self.id_list[source_id_index]])
+ same = torch.ones(1)
+ else:
+ # 在不同id的文件列表中选择
+ target_id_index = random.choice(list(set(range(self.length)) - set([source_id_index])))
+ target_file = random.choice(self.file_index[self.id_list[target_id_index]])
+ same = torch.zeros(1)
+
+ source_file = self.data_root / Path(source_file)
+ target_file = self.data_root / Path(target_file)
+ target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
+
+ target_img = Image.open(target_file.as_posix()).convert("RGB")
+ source_img = Image.open(source_file.as_posix()).convert("RGB")
+
+ target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
+
+ source_img = self.transform(source_img)
+ target_img = self.transform(target_img)
+ target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
+
+ return {
+ "source_image": source_img,
+ "target_image": target_img,
+ "target_mask": target_mask,
+ "same": same,
+ # "source_img_name": source_file.as_posix(),
+ # "target_img_name": target_file.as_posix(),
+ # "target_mask_name": target_mask_file.as_posix(),
+ }
+
+
+class OneToManyTrainDataset(Dataset):
+ def __init__(self, dataset_root: str, dataset_index: str, source_name: str, same_rate=0.5):
+ """
+ One-to-many 训练数据集构建
+ Parameters:
+ -----------
+ dataset_root: str, 数据集根目录
+ dataset_index: str, 数据集index文件路径
+ source_name: str, source face id的名称, one-to-many里面的one
+ same_rate: float, 每个batch里面相同人脸所占的比例
+ """
+ super(OneToManyTrainDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((256, 256)),
+ transforms.ToTensor(),
+ ]
+ )
+ self.data_root = Path(dataset_root)
+ with open(dataset_index, "rb") as f:
+ self.file_index = pickle.load(f, encoding="bytes")
+ self.same_rate = same_rate
+ self.source_name = source_name
+
+ self.id_list: List[str] = list(self.file_index.keys())
+
+ try:
+ self.source_id_index: int = self.id_list.index(self.source_name)
+ except Exception:
+ raise Exception(f"{self.source_name} not in dataset dir")
+
+ # 所有id都遍历一遍,视为一个epoch
+ self.length = len(self.id_list)
+ self.image_num = sum([len(v) for v in self.file_index.values()])
+ self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
+ logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
+ logger.info(f"will use mask mode: {self.mask_dir}")
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index):
+ target_id_index = index
+ target_file = random.choice(self.file_index[self.id_list[target_id_index]])
+ if random.random() < self.same_rate:
+ # 在相同id的文件列表中选择
+ source_file = random.choice(self.file_index[self.id_list[target_id_index]])
+ same = torch.ones(1)
+ else:
+ # 直接选择source name中的图片
+ source_file = random.choice(self.file_index[self.source_name])
+ # 如果和target同个id
+ if self.source_id_index == target_id_index:
+ same = torch.ones(1)
+ else:
+ same = torch.zeros(1)
+
+ source_file = self.data_root / Path(source_file)
+ target_file = self.data_root / Path(target_file)
+ target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
+
+ target_img = Image.open(target_file.as_posix()).convert("RGB")
+ source_img = Image.open(source_file.as_posix()).convert("RGB")
+
+ target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
+
+ source_img = self.transform(source_img)
+ target_img = self.transform(target_img)
+ target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
+
+ return {
+ "source_image": source_img,
+ "target_image": target_img,
+ "target_mask": target_mask,
+ "same": same,
+ # "source_img_name": source_file.as_posix(),
+ # "target_img_name": target_file.as_posix(),
+ # "target_mask_name": target_mask_file.as_posix(),
+ }
+
+
+class TrainDatasetDataLoader:
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self):
+ """Initialize this class"""
+ opt = TrainConfig()
+ if opt.mode is FaceSwapMode.MANY_TO_MANY:
+ self.dataset = ManyToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.same_rate)
+ elif opt.mode is FaceSwapMode.ONE_TO_MANY:
+ logger.info(f"In one-to-many mode, source face is {opt.source_name}")
+ self.dataset = OneToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.source_name, opt.same_rate)
+ else:
+ raise NotImplementedError
+ logger.info(f"dataset {type(self.dataset).__name__} created")
+ if opt.use_ddp:
+ self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, shuffle=True)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ num_workers=int(opt.num_threads),
+ drop_last=True,
+ sampler=self.train_sampler,
+ pin_memory=True,
+ )
+ else:
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=True,
+ num_workers=int(opt.num_threads),
+ drop_last=True,
+ pin_memory=True,
+ )
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return len(self.dataset)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for data in self.dataloader:
+ yield data
+
+
+if __name__ == "__main__":
+ dataloader = TrainDatasetDataLoader()
+ for idx, data in enumerate(dataloader):
+ # print(data["source_img_name"])
+ # print(data["target_img_name"])
+ # print(data["target_mask_name"])
+ print(data["same"])
diff --git a/data_process/__pycache__/generate_mask.cpython-310.pyc b/data_process/__pycache__/generate_mask.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..791f94f12def96266b8468c16118da8f286f6a67
Binary files /dev/null and b/data_process/__pycache__/generate_mask.cpython-310.pyc differ
diff --git a/data_process/__pycache__/model.cpython-310.pyc b/data_process/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86708f8728c5684c4862764abee4fcb5b6d35425
Binary files /dev/null and b/data_process/__pycache__/model.cpython-310.pyc differ
diff --git a/data_process/__pycache__/resnet.cpython-310.pyc b/data_process/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20956e76f2ae928f5c2fea26849976c257bf58a7
Binary files /dev/null and b/data_process/__pycache__/resnet.cpython-310.pyc differ
diff --git a/data_process/__pycache__/utils.cpython-310.pyc b/data_process/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e8edcb8747356a84264468d982ad4d89d545515
Binary files /dev/null and b/data_process/__pycache__/utils.cpython-310.pyc differ
diff --git a/data_process/generate_mask.py b/data_process/generate_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..f76cffc6e0d05c81269614d2c0d328693de6506d
--- /dev/null
+++ b/data_process/generate_mask.py
@@ -0,0 +1,75 @@
+import os
+from pathlib import Path
+
+import cv2
+import torch
+from model import BiSeNet
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm import tqdm
+
+# For BiSeNet and for official_224 SimSwap
+
+
+class MaskDataset(Dataset):
+ def __init__(self, img_root, mask_root):
+ img_dir = Path(img_root)
+ self.to_tensor_normalize = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ self.img_files = list(img_dir.glob(f"**/*.jpg"))
+ self.img_files.sort()
+ self.mask_files = [os.path.join(mask_root, os.path.relpath(img_path, img_root)) for img_path in self.img_files]
+
+ def __len__(self):
+ return len(self.mask_files)
+
+ def __getitem__(self, index):
+ img = Image.open(self.img_files[index]).convert("RGB")
+ return {"img": self.to_tensor_normalize(img), "mask_path": self.mask_files[index]}
+
+
+class MaskDataLoader:
+ def __init__(self):
+ """Initialize this class"""
+ self.dataset = MaskDataset(img_root="/data/dataset/face_1k/alignHQ", mask_root="/data/dataset/face_1k/mask")
+
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset, batch_size=8, shuffle=True, num_workers=8, drop_last=False
+ )
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return len(self.dataset) / 8
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for data in self.dataloader:
+ yield data
+
+
+if __name__ == "__main__":
+ dataloader = MaskDataLoader()
+ bisenet_path = "/data/useful_ckpt/face_parsing/parsing_model_79999_iter.pth"
+ bisenet = BiSeNet(n_classes=19)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ bisenet.to(device)
+ state_dict = torch.load(bisenet_path, map_location=device)
+ bisenet.load_state_dict(state_dict)
+ bisenet.eval()
+
+ for data in tqdm(dataloader):
+ mask, ignore_ids = bisenet.get_mask(data["img"].to(device), 256)
+ mask = (mask * 255).to(torch.uint8).cpu().numpy().transpose(0, 2, 3, 1).repeat(3, 3)
+
+ for i in range(mask.shape[0]):
+ if ignore_ids[i]:
+ continue
+ path = data["mask_path"][i]
+ dirname = os.path.dirname(path)
+ os.makedirs(dirname, exist_ok=True)
+ cv2.imwrite(path, mask[i])
diff --git a/data_process/model.py b/data_process/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8544c409bf858cf8856ef1a61b3c23a869901362
--- /dev/null
+++ b/data_process/model.py
@@ -0,0 +1,307 @@
+#!/usr/bin/python
+# -*- encoding: utf-8 -*-
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from data_process.resnet import Resnet18
+from data_process.utils import encode_segmentation_rgb_batch
+
+
+class ConvBNReLU(nn.Module):
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
+ super(ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(
+ in_chan,
+ out_chan,
+ kernel_size=ks,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_chan)
+ self.init_weight()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(self.bn(x))
+ return x
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+
+class BiSeNetOutput(nn.Module):
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+ super(BiSeNetOutput, self).__init__()
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+ self.init_weight()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.conv_out(x)
+ return x
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+ def get_params(self):
+ wd_params, nowd_params = [], []
+ for name, module in self.named_modules():
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ wd_params.append(module.weight)
+ if module.bias is not None:
+ nowd_params.append(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nowd_params += list(module.parameters())
+ return wd_params, nowd_params
+
+
+class AttentionRefinementModule(nn.Module):
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
+ super(AttentionRefinementModule, self).__init__()
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+ self.bn_atten = nn.BatchNorm2d(out_chan)
+ self.sigmoid_atten = nn.Sigmoid()
+ self.init_weight()
+
+ def forward(self, x):
+ feat = self.conv(x)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv_atten(atten)
+ atten = self.bn_atten(atten)
+ atten = self.sigmoid_atten(atten)
+ out = torch.mul(feat, atten)
+ return out
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+
+class ContextPath(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super(ContextPath, self).__init__()
+ self.resnet = Resnet18()
+ self.arm16 = AttentionRefinementModule(256, 128)
+ self.arm32 = AttentionRefinementModule(512, 128)
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+ self.init_weight()
+
+ def forward(self, x):
+ H0, W0 = x.size()[2:]
+ feat8, feat16, feat32 = self.resnet(x)
+ H8, W8 = feat8.size()[2:]
+ H16, W16 = feat16.size()[2:]
+ H32, W32 = feat32.size()[2:]
+
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
+ avg = self.conv_avg(avg)
+ avg_up = F.interpolate(avg, (H32, W32), mode="nearest")
+
+ feat32_arm = self.arm32(feat32)
+ feat32_sum = feat32_arm + avg_up
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode="nearest")
+ feat32_up = self.conv_head32(feat32_up)
+
+ feat16_arm = self.arm16(feat16)
+ feat16_sum = feat16_arm + feat32_up
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode="nearest")
+ feat16_up = self.conv_head16(feat16_up)
+
+ return feat8, feat16_up, feat32_up # x8, x8, x16
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+ def get_params(self):
+ wd_params, nowd_params = [], []
+ for name, module in self.named_modules():
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ wd_params.append(module.weight)
+ if module.bias is not None:
+ nowd_params.append(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nowd_params += list(module.parameters())
+ return wd_params, nowd_params
+
+
+# This is not used, since I replace this with the resnet feature with the same size
+class SpatialPath(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super(SpatialPath, self).__init__()
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+ self.init_weight()
+
+ def forward(self, x):
+ feat = self.conv1(x)
+ feat = self.conv2(feat)
+ feat = self.conv3(feat)
+ feat = self.conv_out(feat)
+ return feat
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+ def get_params(self):
+ wd_params, nowd_params = [], []
+ for name, module in self.named_modules():
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ wd_params.append(module.weight)
+ if module.bias is not None:
+ nowd_params.append(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nowd_params += list(module.parameters())
+ return wd_params, nowd_params
+
+
+class FeatureFusionModule(nn.Module):
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
+ super(FeatureFusionModule, self).__init__()
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+ self.sigmoid = nn.Sigmoid()
+ self.init_weight()
+
+ def forward(self, fsp, fcp):
+ fcat = torch.cat([fsp, fcp], dim=1)
+ feat = self.convblk(fcat)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv1(atten)
+ atten = self.relu(atten)
+ atten = self.conv2(atten)
+ atten = self.sigmoid(atten)
+ feat_atten = torch.mul(feat, atten)
+ feat_out = feat_atten + feat
+ return feat_out
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+ def get_params(self):
+ wd_params, nowd_params = [], []
+ for name, module in self.named_modules():
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ wd_params.append(module.weight)
+ if module.bias is not None:
+ nowd_params.append(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nowd_params += list(module.parameters())
+ return wd_params, nowd_params
+
+
+class BiSeNet(nn.Module):
+ def __init__(self, n_classes, *args, **kwargs):
+ super(BiSeNet, self).__init__()
+ self.cp = ContextPath()
+ # here self.sp is deleted
+ self.ffm = FeatureFusionModule(256, 256)
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+ self.init_weight()
+
+ def get_mask(self, x: torch.Tensor, crop_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = F.interpolate(x, size=(512, 512))
+
+ parsed_face = self.forward(x)[0]
+
+ parsed_face = torch.argmax(parsed_face, dim=1, keepdim=True)
+
+ parsed_face = encode_segmentation_rgb_batch(parsed_face)
+
+ parsed_face = torch.where(
+ torch.sum(parsed_face, dim=[1, 2, 3], keepdim=True) > 5000,
+ parsed_face,
+ torch.zeros_like(parsed_face),
+ )
+
+ ignore_mask_ids = torch.sum(parsed_face, dim=[1, 2, 3]) == 0
+
+ parsed_face = parsed_face.float().mul_(1 / 255.0)
+
+ parsed_face = F.interpolate(parsed_face, size=(crop_size, crop_size), mode="bilinear")
+
+ parsed_face = torch.sum(parsed_face, dim=1, keepdim=True)
+
+ return parsed_face, ignore_mask_ids
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+ feat_out = self.conv_out(feat_fuse)
+ feat_out16 = self.conv_out16(feat_cp8)
+ feat_out32 = self.conv_out32(feat_cp16)
+
+ feat_out = F.interpolate(feat_out, (H, W), mode="bilinear", align_corners=True)
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode="bilinear", align_corners=True)
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode="bilinear", align_corners=True)
+ return feat_out, feat_out16, feat_out32
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+ def get_params(self):
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+ for name, child in self.named_children():
+ child_wd_params, child_nowd_params = child.get_params()
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
+ lr_mul_wd_params += child_wd_params
+ lr_mul_nowd_params += child_nowd_params
+ else:
+ wd_params += child_wd_params
+ nowd_params += child_nowd_params
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+if __name__ == "__main__":
+ net = BiSeNet(19)
+ net.cuda()
+ net.eval()
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
+ out, out16, out32 = net(in_ten)
+ print(out.shape)
+
+ net.get_params()
diff --git a/data_process/resnet.py b/data_process/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb1733d12489877db239f454138bf39bc4db0307
--- /dev/null
+++ b/data_process/resnet.py
@@ -0,0 +1,106 @@
+#!/usr/bin/python
+# -*- encoding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as modelzoo
+
+# from modules.bn import InPlaceABNSync as BatchNorm2d
+
+resnet18_url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, in_chan, out_chan, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
+ self.bn1 = nn.BatchNorm2d(out_chan)
+ self.conv2 = conv3x3(out_chan, out_chan)
+ self.bn2 = nn.BatchNorm2d(out_chan)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ if in_chan != out_chan or stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_chan),
+ )
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = F.relu(self.bn1(residual))
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = shortcut + residual
+ out = self.relu(out)
+ return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+ for i in range(bnum - 1):
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
+ return nn.Sequential(*layers)
+
+
+class Resnet18(nn.Module):
+ def __init__(self):
+ super(Resnet18, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+ self.init_weight()
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(self.bn1(x))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ feat8 = self.layer2(x) # 1/8
+ feat16 = self.layer3(feat8) # 1/16
+ feat32 = self.layer4(feat16) # 1/32
+ return feat8, feat16, feat32
+
+ def init_weight(self):
+ state_dict = modelzoo.load_url(resnet18_url)
+ self_state_dict = self.state_dict()
+ for k, v in state_dict.items():
+ if "fc" in k:
+ continue
+ self_state_dict.update({k: v})
+ self.load_state_dict(self_state_dict)
+
+ def get_params(self):
+ wd_params, nowd_params = [], []
+ for name, module in self.named_modules():
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ wd_params.append(module.weight)
+ if module.bias is not None:
+ nowd_params.append(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nowd_params += list(module.parameters())
+ return wd_params, nowd_params
+
+
+if __name__ == "__main__":
+ net = Resnet18()
+ x = torch.randn(16, 3, 224, 224)
+ out = net(x)
+ print(out[0].size())
+ print(out[1].size())
+ print(out[2].size())
+ net.get_params()
diff --git a/data_process/utils.py b/data_process/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db42917be61e0baca077a9631eaa92c7962a0f37
--- /dev/null
+++ b/data_process/utils.py
@@ -0,0 +1,99 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class SoftErosion(torch.nn.Module):
+ def __init__(self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1):
+ super(SoftErosion, self).__init__()
+ r = kernel_size // 2
+ self.padding = r
+ self.iterations = iterations
+ self.threshold = threshold
+
+ # Create kernel
+ y_indices, x_indices = torch.meshgrid(torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size))
+ dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
+ kernel = dist.max() - dist
+ kernel /= kernel.sum()
+ kernel = kernel.view(1, 1, *kernel.shape)
+ self.register_buffer("weight", kernel)
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ for i in range(self.iterations - 1):
+ x = torch.min(
+ x,
+ F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding),
+ )
+ x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
+
+ mask = x >= self.threshold
+
+ x[mask] = 1.0
+ # add small epsilon to avoid Nans
+ x[~mask] /= x[~mask].max() + 1e-7
+
+ return x, mask
+
+
+def encode_segmentation_rgb(segmentation: np.ndarray, no_neck: bool = True) -> np.ndarray:
+ parse = segmentation
+ # https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
+ face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
+ mouth_id = 11
+ # hair_id = 17
+ face_map = np.zeros([parse.shape[0], parse.shape[1]])
+ mouth_map = np.zeros([parse.shape[0], parse.shape[1]])
+ # hair_map = np.zeros([parse.shape[0], parse.shape[1]])
+
+ for valid_id in face_part_ids:
+ valid_index = np.where(parse == valid_id)
+ face_map[valid_index] = 255
+ valid_index = np.where(parse == mouth_id)
+ mouth_map[valid_index] = 255
+ # valid_index = np.where(parse==hair_id)
+ # hair_map[valid_index] = 255
+ # return np.stack([face_map, mouth_map,hair_map], axis=2)
+ return np.stack([face_map, mouth_map], axis=2)
+
+
+def encode_segmentation_rgb_batch(segmentation: torch.Tensor, no_neck: bool = True) -> torch.Tensor:
+ # https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
+ face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
+ mouth_id = 11
+ # hair_id = 17
+ segmentation = segmentation.int()
+ face_map = torch.zeros_like(segmentation)
+ mouth_map = torch.zeros_like(segmentation)
+ # hair_map = np.zeros([parse.shape[0], parse.shape[1]])
+
+ white_tensor = face_map + 255
+ for valid_id in face_part_ids:
+ face_map = torch.where(segmentation == valid_id, white_tensor, face_map)
+ mouth_map = torch.where(segmentation == mouth_id, white_tensor, mouth_map)
+
+ return torch.cat([face_map, mouth_map], dim=1)
+
+
+def postprocess(
+ swapped_face: np.ndarray,
+ target: np.ndarray,
+ target_mask: np.ndarray,
+ smooth_mask: torch.nn.Module,
+) -> np.ndarray:
+ # target_mask = cv2.resize(target_mask, (self.size, self.size))
+
+ mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1 / 255.0).cuda()
+ face_mask_tensor = mask_tensor[0] + mask_tensor[1]
+
+ soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
+ soft_face_mask_tensor.squeeze_()
+
+ soft_face_mask = soft_face_mask_tensor.cpu().numpy()
+ soft_face_mask = soft_face_mask[:, :, np.newaxis]
+
+ result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
+ result = result[:, :, ::-1] # .astype(np.uint8)
+ return result
diff --git a/entry/__pycache__/train.cpython-310.pyc b/entry/__pycache__/train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdffad657a201479c74bc40b82fb5d88e6f42a28
Binary files /dev/null and b/entry/__pycache__/train.cpython-310.pyc differ
diff --git a/entry/inference.py b/entry/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c2dfa7585b7323b54746847c25d24e03138742
--- /dev/null
+++ b/entry/inference.py
@@ -0,0 +1,48 @@
+from typing import Optional
+
+import cv2
+import numpy as np
+import torch
+
+from configs.train_config import TrainConfig
+from models.model import HifiFace
+
+
+def inference(source_face: str, target_face: str, model_path: str, model_idx: Optional[int], swapped_face: str):
+ opt = TrainConfig()
+ opt.use_ddp = False
+
+ device = "cpu"
+ checkpoint = (model_path, model_idx)
+ model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
+ model.eval()
+
+ src = cv2.cvtColor(cv2.imread(source_face), cv2.COLOR_BGR2RGB)
+ src = cv2.resize(src, (256, 256))
+ src = src.transpose(2, 0, 1)
+ src = torch.from_numpy(src).unsqueeze(0).to(device).float()
+ src = src / 255.0
+
+ tgt = cv2.cvtColor(cv2.imread(target_face), cv2.COLOR_BGR2RGB)
+ tgt = cv2.resize(tgt, (256, 256))
+ tgt = tgt.transpose(2, 0, 1)
+ tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
+ tgt = tgt / 255.0
+
+ with torch.no_grad():
+ result_face = model.forward(src, tgt).cpu()
+ result_face = torch.clamp(result_face, 0, 1) * 255
+ result_face = result_face.numpy()[0].astype(np.uint8)
+ result_face = result_face.transpose(1, 2, 0)
+
+ result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB)
+ cv2.imwrite(swapped_face, result_face)
+
+
+if __name__ == "__main__":
+ source_face = "/home/xuehongyang/data/female_1.jpg"
+ target_face = "/home/xuehongyang/data/female_2.jpg"
+ model_path = "/data/checkpoints/hififace/baseline_1k_ddp_with_cyc_1681278017147"
+ model_idx = 80000
+ swapped_face = "/home/xuehongyang/data/male_1_to_male_2.jpg"
+ inference(source_face, target_face, model_path, model_idx, swapped_face)
diff --git a/entry/train.py b/entry/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b9e6f6e2ef5915938ffa83ed60d8444dba9dfa
--- /dev/null
+++ b/entry/train.py
@@ -0,0 +1,96 @@
+import os
+import sys
+
+import torch
+from loguru import logger
+
+from configs.train_config import TrainConfig
+from data.dataset import TrainDatasetDataLoader
+from models.model import HifiFace
+from utils.visualizer import Visualizer
+
+use_ddp = TrainConfig().use_ddp
+if use_ddp:
+
+ import torch.distributed as dist
+
+ def setup():
+ # os.environ["MASTER_ADDR"] = "localhost"
+ # os.environ["MASTER_PORT"] = "12345"
+ dist.init_process_group("nccl") # , rank=rank, world_size=world_size)
+ return dist.get_rank()
+
+ def cleanup():
+ dist.destroy_process_group()
+
+
+def train():
+ rank = 0
+ if use_ddp:
+ rank = setup()
+ device = torch.device(f"cuda:{rank}")
+ logger.info(f"use device {device}")
+
+ opt = TrainConfig()
+ dataloader = TrainDatasetDataLoader()
+ dataset_length = len(dataloader)
+ logger.info(f"Dataset length: {dataset_length}")
+
+ model = HifiFace(
+ opt.identity_extractor_config, is_training=True, device=device, load_checkpoint=opt.load_checkpoint
+ )
+ model.train()
+
+ logger.info("model initialized")
+ visualizer = None
+ ckpt = False
+ if not opt.use_ddp or rank == 0:
+ visualizer = Visualizer(opt)
+ ckpt = True
+
+ total_iter = 0
+ epoch = 0
+ while True:
+ if opt.use_ddp:
+ dataloader.train_sampler.set_epoch(epoch)
+ for data in dataloader:
+ source_image = data["source_image"].to(device)
+ target_image = data["target_image"].to(device)
+ targe_mask = data["target_mask"].to(device)
+ same = data["same"].to(device)
+ loss_dict, visual_dict = model.optimize(source_image, target_image, targe_mask, same)
+
+ total_iter += 1
+
+ if total_iter % opt.visualize_interval == 0 and visualizer is not None:
+ visualizer.display_current_results(total_iter, visual_dict)
+
+ if total_iter % opt.plot_interval == 0 and visualizer is not None:
+ visualizer.plot_current_losses(total_iter, loss_dict)
+ logger.info(f"Iter: {total_iter}")
+ for k, v in loss_dict.items():
+ logger.info(f" {k}: {v}")
+ logger.info("=" * 20)
+
+ if total_iter % opt.checkpoint_interval == 0 and ckpt:
+ logger.info(f"Saving model at iter {total_iter}")
+ model.save(opt.checkpoint_dir, total_iter)
+
+ if total_iter > opt.max_iters:
+ logger.info(f"Maximum iterations exceeded. Stopping training.")
+ if ckpt:
+ model.save(opt.checkpoint_dir, total_iter)
+ if use_ddp:
+ cleanup()
+ sys.exit(0)
+ epoch += 1
+
+
+if __name__ == "__main__":
+ if use_ddp:
+ # CUDA_VISIBLE_DEVICES=2,3 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 -m entry.train
+ os.environ["OMP_NUM_THREADS"] = "1"
+ n_gpus = torch.cuda.device_count()
+ train()
+ else:
+ train()
diff --git a/models/__pycache__/discriminator.cpython-310.pyc b/models/__pycache__/discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e62cc2dc33963898d5153f920b0f6c42facd2d03
Binary files /dev/null and b/models/__pycache__/discriminator.cpython-310.pyc differ
diff --git a/models/__pycache__/gan_loss.cpython-310.pyc b/models/__pycache__/gan_loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83a33733a7f5c43a24da4bca73bb819315f03b18
Binary files /dev/null and b/models/__pycache__/gan_loss.cpython-310.pyc differ
diff --git a/models/__pycache__/generator.cpython-310.pyc b/models/__pycache__/generator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2708d710e6e39006d14a982e184a7983a1427cbc
Binary files /dev/null and b/models/__pycache__/generator.cpython-310.pyc differ
diff --git a/models/__pycache__/init_weight.cpython-310.pyc b/models/__pycache__/init_weight.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65a8cd4b6043293284db8f4a99f4116360c561b8
Binary files /dev/null and b/models/__pycache__/init_weight.cpython-310.pyc differ
diff --git a/models/__pycache__/model.cpython-310.pyc b/models/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e90308a535317f73d02b79173081c19cbfec57e
Binary files /dev/null and b/models/__pycache__/model.cpython-310.pyc differ
diff --git a/models/__pycache__/model_blocks.cpython-310.pyc b/models/__pycache__/model_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f33353e9bb811bcedffdf184acef2d05ab17fc4
Binary files /dev/null and b/models/__pycache__/model_blocks.cpython-310.pyc differ
diff --git a/models/__pycache__/semantic_face_fusion_model.cpython-310.pyc b/models/__pycache__/semantic_face_fusion_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d808e60441d0f7f52c3b7e478e9a763f41b26fb
Binary files /dev/null and b/models/__pycache__/semantic_face_fusion_model.cpython-310.pyc differ
diff --git a/models/__pycache__/shape_aware_identity_model.cpython-310.pyc b/models/__pycache__/shape_aware_identity_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6a1eeb95b9b4f83a679d38a42f0c0f0a49c51dc
Binary files /dev/null and b/models/__pycache__/shape_aware_identity_model.cpython-310.pyc differ
diff --git a/models/discriminator.py b/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f13161f5b401eff3c063739550f6636e5b53f39a
--- /dev/null
+++ b/models/discriminator.py
@@ -0,0 +1,26 @@
+import numpy as np
+import torch.nn as nn
+
+from models.model_blocks import ResBlock
+
+
+class Discriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=6):
+ super(Discriminator, self).__init__()
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)]
+ for i in range(n_layers):
+ if i >= 3:
+ sequence += [ResBlock(512, 512, down_sample=True, norm=False)]
+ else:
+ mult = 2**i
+ sequence += [ResBlock(ndf * mult, ndf * mult * 2, down_sample=True, norm=False)]
+ sequence += [
+ nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 2, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, inplace=True),
+ ]
+ self.sequence = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ return self.sequence(input)
diff --git a/models/gan_loss.py b/models/gan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..28bf698f69c51bb206ee304f08e5d840eb7c76c7
--- /dev/null
+++ b/models/gan_loss.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GANLoss(nn.Module):
+ def __init__(self, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor, opt=None):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_tensor = None
+ self.fake_label_tensor = None
+ self.zero_tensor = None
+ self.Tensor = tensor
+ self.opt = opt
+
+ def get_target_tensor(self, input, target_is_real):
+ if target_is_real:
+ return torch.ones_like(input).detach()
+ else:
+ return torch.zeros_like(input).detach()
+
+ def get_zero_tensor(self, input):
+ return torch.zeros_like(input).detach()
+
+ def loss(self, inputs, target_is_real, for_discriminator=True):
+ target_tensor = self.get_target_tensor(inputs, target_is_real)
+ loss = F.binary_cross_entropy_with_logits(inputs, target_tensor)
+ return loss
+
+ def __call__(self, inputs, target_is_real, for_discriminator=True):
+ # computing loss is a bit complicated because |input| may not be
+ # a tensor, but list of tensors in case of multiscale discriminator
+ if isinstance(inputs, list):
+ loss = 0
+ for pred_i in inputs:
+ if isinstance(pred_i, list):
+ pred_i = pred_i[-1]
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+ loss += new_loss
+ return loss / len(inputs)
+ else:
+ return self.loss(inputs, target_is_real, for_discriminator)
diff --git a/models/generator.py b/models/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..36609df0f80d90d94152a21b78abfc4f5679f0f7
--- /dev/null
+++ b/models/generator.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+
+from models.init_weight import init_net
+from models.model_blocks import AdaInResBlock
+from models.model_blocks import ResBlock
+from models.semantic_face_fusion_model import SemanticFaceFusionModule
+from models.shape_aware_identity_model import ShapeAwareIdentityExtractor
+
+
+class Encoder(nn.Module):
+ """
+ Hififace encoder part
+ """
+
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+
+ self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
+ self.down_sample = [True, True, True, True, True, False, False]
+
+ self.block_list = nn.ModuleList()
+
+ for i in range(7):
+ self.block_list.append(
+ ResBlock(self.channel_list[i], self.channel_list[i + 1], down_sample=self.down_sample[i])
+ )
+
+ def forward(self, x):
+ x = self.conv_first(x)
+ z_enc = None
+
+ for i in range(7):
+ x = self.block_list[i](x)
+ if i == 1:
+ z_enc = x
+ return z_enc, x
+
+
+class Decoder(nn.Module):
+ """
+ Hififace decoder part
+ """
+
+ def __init__(self):
+ super(Decoder, self).__init__()
+ self.block_list = nn.ModuleList()
+ self.channel_list = [512, 512, 512, 512, 512, 256]
+ self.up_sample = [False, False, True, True, True]
+
+ for i in range(5):
+ self.block_list.append(
+ AdaInResBlock(self.channel_list[i], self.channel_list[i + 1], up_sample=self.up_sample[i])
+ )
+
+ def forward(self, x, id_vector):
+ """
+ Parameters:
+ -----------
+ x: encoder encoded feature map
+ id_vector: 3d shape aware identity vector
+
+ Returns:
+ --------
+ z_dec
+ """
+ for i in range(5):
+ x = self.block_list[i](x, id_vector)
+ return x
+
+
+class Generator(nn.Module):
+ """
+ Hififace Generator
+ """
+
+ def __init__(self, identity_extractor_config):
+ super(Generator, self).__init__()
+ self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
+ self.id_extractor.requires_grad_(False)
+ self.encoder = init_net(Encoder())
+ self.decoder = init_net(Decoder())
+ self.sff_module = init_net(SemanticFaceFusionModule())
+
+ @torch.no_grad()
+ def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
+ shape_aware_id_vector = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
+ z_enc, x = self.encoder(i_target)
+ z_dec = self.decoder(x, shape_aware_id_vector)
+
+ i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, shape_aware_id_vector)
+
+ return i_r, i_low, m_r, m_low
+
+ def forward(self, i_source, i_target, need_id_grad=False):
+ """
+ Parameters:
+ -----------
+ i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
+ i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
+ need_id_grad: bool, whether to calculate id extractor module's gradient
+
+ Returns:
+ --------
+ i_r: torch.Tensor
+ i_low: torch.Tensor
+ m_r: torch.Tensor
+ m_low: torch.Tensor
+ """
+ if need_id_grad:
+ shape_aware_id_vector = self.id_extractor(i_source, i_target)
+ else:
+ with torch.no_grad():
+ shape_aware_id_vector = self.id_extractor(i_source, i_target)
+ z_enc, x = self.encoder(i_target)
+ z_dec = self.decoder(x, shape_aware_id_vector)
+
+ i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, shape_aware_id_vector)
+
+ return i_r, i_low, m_r, m_low
diff --git a/models/init_weight.py b/models/init_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f7f22e93b67b95340309b504c7da40d80df1e64
--- /dev/null
+++ b/models/init_weight.py
@@ -0,0 +1,57 @@
+import torch
+from torch.nn import init
+
+
+def init_weights(net, init_type="normal", init_gain=0.02):
+ """Initialize network weights.
+
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
+ if init_type == "normal":
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == "xavier":
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == "kaiming":
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+ elif init_type == "orthogonal":
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
+ if hasattr(m, "bias") and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif (
+ classname.find("BatchNorm2d") != -1
+ ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ # print("initialize network with %s" % init_type)
+ net.apply(init_func) # apply the initialization function
+
+
+def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=[]):
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+ Parameters:
+ net (network) -- the network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Return an initialized network.
+ """
+ if len(gpu_ids) > 0:
+ assert torch.cuda.is_available()
+ net.to(gpu_ids[0])
+ # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
+ init_weights(net, init_type, init_gain=init_gain)
+ return net
diff --git a/models/model.py b/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d33df1efed9247cb89567c49f8e857b492f8555
--- /dev/null
+++ b/models/model.py
@@ -0,0 +1,435 @@
+import os
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import kornia
+import lpips
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from loguru import logger
+
+from arcface_torch.backbones.iresnet import iresnet100
+from configs.train_config import TrainConfig
+from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
+from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
+from HRNet.hrnet import HighResolutionNet
+from models.discriminator import Discriminator
+from models.gan_loss import GANLoss
+from models.generator import Generator
+from models.init_weight import init_net
+
+
+class HifiFace:
+ def __init__(
+ self,
+ identity_extractor_config,
+ is_training=True,
+ device="cpu",
+ load_checkpoint: Optional[Tuple[str, int]] = None,
+ ):
+ super(HifiFace, self).__init__()
+ self.generator = Generator(identity_extractor_config)
+ self.is_training = is_training
+
+ if self.is_training:
+ self.lr = TrainConfig().lr
+ self.use_ddp = TrainConfig().use_ddp
+ self.grad_clip = TrainConfig().grad_clip if TrainConfig().grad_clip is not None else 100.0
+
+ self.discriminator = init_net(Discriminator(3))
+
+ self.l1_loss = nn.L1Loss()
+ if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
+ self.mse_loss = nn.MSELoss()
+ self.loss_fn_vgg = lpips.LPIPS(net="vgg")
+ self.adv_loss = GANLoss()
+
+ # 3D人脸重建模型
+ self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
+ self.f_3d.load_state_dict(
+ torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
+ )
+ self.f_3d.eval()
+ self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
+ self.face_model.to("cpu")
+
+ # 人脸识别模型
+ self.f_id = iresnet100(pretrained=False, fp16=False)
+ self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
+ self.f_id.eval()
+
+ # mouth heatmap model
+ if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
+ self.model_mouth = HighResolutionNet()
+ checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
+ self.model_mouth.load_state_dict(checkpoint)
+ self.model_mouth.eval()
+
+ self.lambda_adv = 1
+ self.lambda_seg = 100
+ self.lambda_rec = 20
+ self.lambda_cyc = 1
+ self.lambda_lpips = 5
+
+ self.lambda_shape = 0.5
+ self.lambda_id = 5
+ self.lambda_eye_hm = 10000.0
+ self.lambda_mouth_hm = 10000.0
+
+ self.dilation_kernel = torch.ones(5, 5)
+
+ if load_checkpoint is not None:
+ self.load(load_checkpoint[0], load_checkpoint[1])
+
+ self.setup(device)
+
+ def save(self, path, idx=None):
+ os.makedirs(path, exist_ok=True)
+ if idx is None:
+ g_path = os.path.join(path, "generator.pth")
+ d_path = os.path.join(path, "discriminator.pth")
+ else:
+ g_path = os.path.join(path, f"generator_{idx}.pth")
+ d_path = os.path.join(path, f"discriminator_{idx}.pth")
+ if self.use_ddp:
+ torch.save(self.generator.module.state_dict(), g_path)
+ torch.save(self.discriminator.module.state_dict(), d_path)
+ else:
+ torch.save(self.generator.state_dict(), g_path)
+ torch.save(self.discriminator.state_dict(), d_path)
+
+ def load(self, path, idx=None):
+ if idx is None:
+ g_path = os.path.join(path, "generator.pth")
+ d_path = os.path.join(path, "discriminator.pth")
+ else:
+ g_path = os.path.join(path, f"generator_{idx}.pth")
+ d_path = os.path.join(path, f"discriminator_{idx}.pth")
+ logger.info(f"Loading generator from {g_path}")
+ self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
+ if self.is_training:
+ logger.info(f"Loading discriminator from {d_path}")
+ self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
+
+ def setup(self, device):
+ self.generator.to(device)
+
+ if self.is_training:
+ self.discriminator.to(device)
+ self.l1_loss.to(device)
+ if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
+ self.mse_loss.to(device)
+ self.f_3d.to(device)
+ self.f_id.to(device)
+
+ self.loss_fn_vgg.to(device)
+ self.face_model.to(device)
+ self.adv_loss.to(device)
+
+ if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
+ self.model_mouth.to(device)
+ self.f_3d.requires_grad_(False)
+ self.f_id.requires_grad_(False)
+ self.loss_fn_vgg.requires_grad_(False)
+ if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
+ self.model_mouth.requires_grad_(False)
+ self.dilation_kernel = self.dilation_kernel.to(device)
+ if self.use_ddp:
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
+
+ self.generator = DDP(self.generator, device_ids=[device])
+ self.discriminator = DDP(self.discriminator, device_ids=[device])
+
+ if dist.get_rank() == 0:
+ torch.save(self.generator.state_dict(), "/tmp/generator.pth")
+ torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")
+
+ dist.barrier()
+ self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
+ self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))
+
+ self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=[0, 0.999])
+ self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=[0, 0.999])
+
+ def train(self):
+ self.generator.train()
+ self.discriminator.train()
+ # 整个id extractor是不训练的模块
+ if self.use_ddp:
+ self.generator.module.id_extractor.eval()
+ else:
+ self.generator.id_extractor.eval()
+
+ def eval(self):
+ self.generator.eval()
+ if self.is_training:
+ self.discriminator.eval()
+
+ def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
+ """
+ 训练时候 Generator的loss计算
+ Parameters:
+ -----------
+ source_img: torch.Tensor
+ target_img: torch.Tensor
+ target_mask: torch.Tensor, [B, 1, H, W]
+ same_id_mask: torch.Tensor, [B, 1]
+
+ Returns:
+ --------
+ source_img: torch.Tensor
+ target_img: torch.Tensor
+ i_cycle: torch.Tensor, cycle image
+ i_r: torch.Tensor
+ m_r: torch.Tensor
+ loss: Dict[torch.Tensor], contain pairs of loss name and loss values
+ """
+ same = same_id_mask.unsqueeze(-1).unsqueeze(-1)
+ i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
+ i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
+ d_r = self.discriminator(i_r)
+
+ # SID Loss: shape loss + id loss
+
+ with torch.no_grad():
+ c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
+ c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
+ c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
+ c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))
+ with torch.no_grad():
+ c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
+ _, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
+ _, _, _, q_r = self.face_model.compute_for_render(c_r)
+ _, _, _, q_low = self.face_model.compute_for_render(c_low)
+ with torch.no_grad():
+ v_id_i_s = F.normalize(
+ self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
+ )
+
+ v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
+ v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
+ loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
+ loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)
+
+ inner_product_r = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_r.unsqueeze(2)).squeeze()
+ inner_product_low = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_low.unsqueeze(2)).squeeze()
+ loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
+ torch.ones_like(inner_product_low), inner_product_low
+ )
+ loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id
+
+ # Realism Loss: segmentation loss + reconstruction loss + cycle loss + perceptual loss + adversarial loss
+
+ loss_cycle = self.l1_loss(target_img, i_cycle)
+
+ # dilate target mask
+ target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)
+
+ loss_segmentation = self.l1_loss(
+ F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
+ ) + self.l1_loss(target_mask, m_r)
+
+ loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
+ i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
+ )
+
+ loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()
+
+ loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)
+
+ loss_realism = (
+ self.lambda_adv * loss_adversarial
+ + self.lambda_seg * loss_segmentation
+ + self.lambda_rec * loss_reconstruction
+ + self.lambda_cyc * loss_cycle
+ + self.lambda_lpips * loss_perceptual
+ )
+
+ # eye hm loss
+ loss_eye_hm = 0
+ # mouth hm loss
+ loss_mouth_hm = 0
+ if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
+ target_hm = self.model_mouth(target_img)
+ r_hm = self.model_mouth(i_r)
+
+ if TrainConfig().eye_hm_loss:
+ target_eye_hm = target_hm[:, 96:98, :, :]
+ r_eye_hm = r_hm[:, 96:98, :, :]
+ loss_eye_hm = self.mse_loss(r_eye_hm, target_eye_hm)
+ loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm
+
+ if TrainConfig().mouth_hm_loss:
+ target_mouth_hm = target_hm[:, 76:96, :, :]
+ r_mouth_hm = r_hm[:, 76:96, :, :]
+ loss_mouth_hm = self.mse_loss(r_mouth_hm, target_mouth_hm)
+ loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm
+
+ loss_generator = loss_sid + loss_realism
+
+ loss_dict = {
+ "loss_shape": loss_shape,
+ "loss_id": loss_id,
+ "loss_sid": loss_sid,
+ "loss_cycle": loss_cycle,
+ "loss_segmentation": loss_segmentation,
+ "loss_reconstruction": loss_reconstruction,
+ "loss_perceptual": loss_perceptual,
+ "loss_adversarial": loss_adversarial,
+ "loss_realism": loss_realism,
+ "loss_generator": loss_generator,
+ }
+ if TrainConfig().eye_hm_loss:
+ loss_dict.update({"loss_eye_hm": loss_eye_hm})
+ if TrainConfig().mouth_hm_loss:
+ loss_dict.update({"loss_mouth_hm": loss_mouth_hm})
+ return (
+ source_img,
+ target_img,
+ i_cycle.detach(),
+ i_r.detach(),
+ m_r.detach(),
+ loss_dict,
+ )
+
+ def train_forward_discriminator(self, target_img, i_r):
+ """
+ 训练时候 Discriminator的loss计算
+ Parameters:
+ -----------
+ target_img: torch.Tensor, 目标脸图片
+ i_r: torch.Tensor, 换脸结果
+
+ Returns:
+ --------
+ Dict[str]: contains pair of loss name and loss values
+ """
+ d_gt = self.discriminator(target_img)
+ d_fake = self.discriminator(i_r.detach())
+ loss_real = self.adv_loss(d_gt, True)
+ loss_fake = self.adv_loss(d_fake, False)
+
+ # alpha = torch.rand(target_img.shape[0], 1, 1, 1).to(target_img.device)
+ # x_hat = (alpha * target_img.data + (1 - alpha) * i_r.data).requires_grad_(True)
+ # out = self.discriminator(x_hat)
+ # loss_gp = gradient_penalty(out, x_hat)
+
+ loss_discriminator = loss_real + loss_fake # + 10 * loss_gp
+ return {
+ "loss_real": loss_real,
+ "loss_fake": loss_fake,
+ # "loss_gp": loss_gp,
+ "loss_discriminator": loss_discriminator,
+ }
+
+ def forward(
+ self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
+ ) -> torch.Tensor:
+ """
+ Parameters:
+ -----------
+ source_img: torch.Tensor, source face 图像
+ target_img: torch.Tensor, target face 图像
+ *_rate: 插值系数
+ Returns:
+ --------
+ i_r: torch.Tensor, swapped result
+ """
+ if shape_rate is None and id_rate is None:
+ i_r, _, m_r, _ = self.generator(source_img, target_img)
+ else:
+ if shape_rate is None:
+ shape_rate = 1.0
+ if id_rate is None:
+ id_rate = 1.0
+ i_r, _, m_r, _ = self.generator.interp(source_img, target_img, shape_rate, id_rate)
+ return i_r, m_r
+
+ def optimize(
+ self,
+ source_img: torch.Tensor,
+ target_img: torch.Tensor,
+ target_mask: torch.Tensor,
+ same_id_mask: torch.Tensor,
+ ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
+ """
+ 模型的optimize
+ 训练模式下执行一次训练,并返回loss信息和结果
+ Parameters:
+ -----------
+ source_img: torch.Tensor, source face 图像
+ target_img: torch.Tensor, target face 图像
+ target_mask: torch.Tensor, target face mask
+ same_id_mask: torch.Tensor, same id mask, 标识source 和 target是否是同个人
+
+ Returns:
+ --------
+ Tuple[Dict, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ loss_dict, source_img, target_img, m_r(预测的mask), i_r(换脸结果)
+ """
+ src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
+ source_img, target_img, target_mask, same_id_mask
+ )
+ loss_G = loss_G_dict["loss_generator"]
+ self.g_optimizer.zero_grad()
+ loss_G.backward()
+ global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
+ self.g_optimizer.step()
+
+ loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
+ loss_D = loss_D_dict["loss_discriminator"]
+ self.d_optimizer.zero_grad()
+ loss_D.backward()
+ global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
+ self.d_optimizer.step()
+
+ total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
+ total_loss_dict.update(loss_G_dict)
+ total_loss_dict.update(loss_D_dict)
+
+ return total_loss_dict, {
+ "source face": src_img,
+ "target face": tgt_img,
+ "swapped face": torch.clamp(i_r, min=0.0, max=1.0),
+ "pred face mask": m_r,
+ "cycle face": i_cycle,
+ }
+
+
+if __name__ == "__main__":
+ import torch
+ import cv2
+ from configs.train_config import TrainConfig
+
+ identity_extractor_config = TrainConfig().identity_extractor_config
+
+ model = HifiFace(identity_extractor_config, is_training=True)
+
+ # src = cv2.imread("/home/xuehongyang/data/test1.jpg")
+ # tgt = cv2.imread("/home/xuehongyang/data/test2.jpg")
+ # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
+ # tgt = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)
+ # src = cv2.resize(src, (256, 256))
+ # tgt = cv2.resize(tgt, (256, 256))
+ # src = src.transpose(2, 0, 1)[None, ...]
+ # tgt = tgt.transpose(2, 0, 1)[None, ...]
+ # source_img = torch.from_numpy(src).float() / 255.0
+ # target_img = torch.from_numpy(tgt).float() / 255.0
+ # same_id_mask = torch.Tensor([1]).unsqueeze(0)
+ # tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
+ # if torch.cuda.is_available():
+ # model.to("cuda:3")
+ # source_img = source_img.to("cuda:3")
+ # target_img = target_img.to("cuda:3")
+ # tgt_mask = tgt_mask.to("cuda:3")
+ # same_id_mask = same_id_mask.to("cuda:3")
+ # source_img = source_img.repeat(16, 1, 1, 1)
+ # target_img = target_img.repeat(16, 1, 1, 1)
+ # tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
+ # same_id_mask = same_id_mask.repeat(16, 1)
+ # while True:
+ # x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
+ # print(x[0]["loss_generator"])
diff --git a/models/model_blocks.py b/models/model_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..533307f91dd4368029d69106b8860ceeca8f4d5e
--- /dev/null
+++ b/models/model_blocks.py
@@ -0,0 +1,122 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, down_sample=False, up_sample=False, norm=True):
+ super(ResBlock, self).__init__()
+
+ main_module_list = []
+ if norm:
+ main_module_list += [
+ nn.InstanceNorm2d(in_channel),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ ]
+ else:
+ main_module_list += [
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ ]
+ if down_sample:
+ main_module_list.append(nn.AvgPool2d(kernel_size=2))
+ elif up_sample:
+ main_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
+ if norm:
+ main_module_list += [
+ nn.InstanceNorm2d(out_channel),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ ]
+ else:
+ main_module_list += [
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ ]
+ self.main_path = nn.Sequential(*main_module_list)
+
+ side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
+ if down_sample:
+ side_module_list.append(nn.AvgPool2d(kernel_size=2))
+ elif up_sample:
+ side_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
+ self.side_path = nn.Sequential(*side_module_list)
+
+ def forward(self, x):
+ x1 = self.main_path(x)
+ x2 = self.side_path(x)
+ return x1 + x2
+
+
+class AdaIn(nn.Module):
+ def __init__(self, in_channel, vector_size):
+ super(AdaIn, self).__init__()
+ self.eps = 1e-5
+ self.std_style_fc = nn.Linear(vector_size, in_channel)
+ self.mean_style_fc = nn.Linear(vector_size, in_channel)
+
+ def forward(self, x, style_vector):
+ std_style = self.std_style_fc(style_vector)
+ mean_style = self.mean_style_fc(style_vector)
+
+ std_style = std_style.unsqueeze(-1).unsqueeze(-1)
+ mean_style = mean_style.unsqueeze(-1).unsqueeze(-1)
+
+ x = F.instance_norm(x)
+ x = std_style * x + mean_style
+ return x
+
+
+class AdaInResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, up_sample=False):
+ super(AdaInResBlock, self).__init__()
+ self.vector_size = 257 + 512
+ self.up_sample = up_sample
+
+ self.adain1 = AdaIn(in_channel, self.vector_size)
+ self.adain2 = AdaIn(out_channel, self.vector_size)
+
+ main_module_list = []
+ main_module_list += [
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ ]
+ if up_sample:
+ main_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
+ self.main_path1 = nn.Sequential(*main_module_list)
+
+ self.main_path2 = nn.Sequential(
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
+ )
+
+ side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
+ if up_sample:
+ side_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
+ self.side_path = nn.Sequential(*side_module_list)
+
+ def forward(self, x, id_vector):
+ x1 = self.adain1(x, id_vector)
+ x1 = self.main_path1(x1)
+ x2 = self.side_path(x)
+
+ x1 = self.adain2(x1, id_vector)
+ x1 = self.main_path2(x1)
+
+ return x1 + x2
+
+
+class UpSamplingBlock(nn.Module):
+ def __init__(
+ self,
+ ):
+ super(UpSamplingBlock, self).__init__()
+ self.net = nn.Sequential(ResBlock(256, 256, up_sample=True), ResBlock(256, 256, up_sample=True))
+ self.i_r_net = nn.Sequential(nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 3, 3, 1, 1))
+ self.m_r_net = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
+
+ def forward(self, x):
+ x = self.net(x)
+ i_r = self.i_r_net(x)
+ m_r = self.m_r_net(x)
+ return i_r, m_r
diff --git a/models/semantic_face_fusion_model.py b/models/semantic_face_fusion_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b462177f04b3f9df4395975604df3b89ed396b30
--- /dev/null
+++ b/models/semantic_face_fusion_model.py
@@ -0,0 +1,73 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.model_blocks import AdaInResBlock
+from models.model_blocks import ResBlock
+from models.model_blocks import UpSamplingBlock
+
+
+class SemanticFaceFusionModule(nn.Module):
+ def __init__(self):
+ """
+ Semantic Face Fusion Module
+ to preserve lighting and background
+ """
+ super(SemanticFaceFusionModule, self).__init__()
+
+ self.sigma = ResBlock(256, 256)
+ self.low_mask_predict = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
+ self.z_fuse_block_1 = AdaInResBlock(256, 256)
+ self.z_fuse_block_2 = AdaInResBlock(256, 256)
+
+ self.i_low_block = nn.Sequential(nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 3, 3, 1, 1))
+
+ self.f_up = UpSamplingBlock()
+
+ def forward(self, target_image, z_enc, z_dec, v_sid):
+ """
+ Parameters:
+ ----------
+ target_image: 目标脸图片
+ z_enc: 1/4原图大小的low-level encoder feature map
+ z_dec: 1/4原图大小的low-level decoder feature map
+ v_sid: the 3D shape aware identity vector
+
+ Returns:
+ --------
+ i_r: re-target image
+ i_low: 1/4 size retarget image
+ m_r: face mask
+ m_low: 1/4 size face mask
+ """
+ z_enc = self.sigma(z_enc)
+
+ # 估算z_dec对应的人脸 low-level feature mask
+ m_low = self.low_mask_predict(z_dec)
+
+ # 计算融合的low-level feature map
+ # mask区域使用decoder的low-level特征 + 非mask区域使用encoder的low-level特征
+ z_fuse = m_low * z_dec + (1 - m_low) * z_enc
+
+ z_fuse = self.z_fuse_block_1(z_fuse, v_sid)
+ z_fuse = self.z_fuse_block_2(z_fuse, v_sid)
+
+ i_low = self.i_low_block(z_fuse)
+
+ i_low = m_low * i_low + (1 - m_low) * F.interpolate(target_image, scale_factor=0.25)
+
+ i_r, m_r = self.f_up(z_fuse)
+ i_r = m_r * i_r + (1 - m_r) * target_image
+
+ return i_r, i_low, m_r, m_low
+
+
+if __name__ == "__main__":
+ import torch
+
+ timg = torch.randn(1, 3, 256, 256)
+ z_enc = torch.randn(1, 256, 64, 64)
+ z_dec = torch.randn(1, 256, 64, 64)
+ v_sid = torch.randn(1, 769)
+ model = SemanticFaceFusionModule()
+ i_r, i_low, m_r, m_low = model(timg, z_enc, z_dec, v_sid)
+ print(i_r.shape, i_low.shape, m_r.shape, m_low.shape)
diff --git a/models/shape_aware_identity_model.py b/models/shape_aware_identity_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e98f650751e1e3a5250fd12c4e66c0dd4ef23e5
--- /dev/null
+++ b/models/shape_aware_identity_model.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from arcface_torch.backbones.iresnet import iresnet100
+from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
+
+
+class ShapeAwareIdentityExtractor(nn.Module):
+ def __init__(self, identity_extractor_config):
+ """
+ Shape Aware Identity Extractor
+ Parameters:
+ ----------
+ identity_extractor_config: Dict[str, str]
+ 必须包含以下内容:
+ f_3d_checkpoint_path: str
+ 3D人脸重建模型路径,如"model/Deep3DFaceRecon_pytorch/checkpoints/epoch_20.pth"
+ f_id_checkpoint_path: str
+ arcface人脸识别模型路径
+ 非官方实现用的是https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215585&cid=4A83B6B633B029CC/backbone.pth
+ """
+ super(ShapeAwareIdentityExtractor, self).__init__()
+ f_3d_checkpoint_path = identity_extractor_config["f_3d_checkpoint_path"]
+ f_id_checkpoint_path = identity_extractor_config["f_id_checkpoint_path"]
+ # 3D人脸重建模型
+ self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
+ self.f_3d.load_state_dict(torch.load(f_3d_checkpoint_path, map_location="cpu")["net_recon"])
+ self.f_3d.eval()
+
+ # 人脸识别模型
+ self.f_id = iresnet100(pretrained=False, fp16=False)
+ self.f_id.load_state_dict(torch.load(f_id_checkpoint_path, map_location="cpu"))
+ self.f_id.eval()
+
+ @torch.no_grad()
+ def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
+ """
+ 插值shape和id信息
+ """
+ c_s = self.f_3d(i_source)
+ c_t = self.f_3d(i_target)
+ c_interp = shape_rate * c_s + (1 - shape_rate) * c_t
+ c_fuse = torch.cat((c_interp[:, :80], c_t[:, 80:]), dim=1)
+ # extract source face identity feature
+ v_s = F.normalize(self.f_id(F.interpolate((i_source - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
+ v_t = F.normalize(self.f_id(F.interpolate((i_target - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
+ v_id = id_rate * v_s + (1 - id_rate) * v_t
+ # concat new shape feature and source identity
+ v_sid = torch.cat((c_fuse, v_id), dim=1)
+ return v_sid
+
+ def forward(self, i_source, i_target):
+ """
+ Parameters:
+ -----------
+ i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
+ i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
+
+ Returns:
+ --------
+ v_sid: torch.Tensor, fused shape and id features
+ """
+ # regress 3DMM coefficients
+ c_s = self.f_3d(i_source)
+ c_t = self.f_3d(i_target)
+
+ # generate a new 3D face model: source's identity + target's posture and expression
+ # from https://github.com/sicxu/Deep3DFaceRecon_pytorch/blob/f221678d4b49ca35f1275ba60f721ecb38a2cd19/models/networks.py#L85
+ c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
+
+ # extract source face identity feature
+ v_id = F.normalize(self.f_id(F.interpolate((i_source - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
+
+ # concat new shape feature and source identity
+ v_sid = torch.cat((c_fuse, v_id), dim=1)
+ return v_sid
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..74180cc29e68ecb3d5a948464a9822257dd5725f
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+wget
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..2c189d5b5a9ade107f78fd11dda2b703d2491800
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,15 @@
+{
+ "reportMissingImports": true,
+ "reportMissingTypeStubs": true,
+ "useLibraryCodeForTypes": true,
+ "reportUnusedImport": "warning",
+ "reportUnusedVariable": "warning",
+ "reportDuplicateImport": true,
+ "reportPrivateImportUsage": false,
+ "reportWildcardImportFromLibrary": "warning",
+ "reportTypedDictNotRequiredAccess": false,
+ "reportGeneralTypeIssues": false,
+ "venvPath": "/home/xuehongyang/miniconda3/envs/",
+ "venv": "pytorch-2.0",
+ "stubPath": "/home/xuehongyang/dev_configs/typings"
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0e3ecf9220855088fac25d32b033a9e43f78a223
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+torchaudio
+kornia
+lpips
diff --git a/results/exp_230901_base_1693564635742_320000_1.jpg b/results/exp_230901_base_1693564635742_320000_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c1f04d73a6b5bb4ad2cd0b2d272eb5b967806653
Binary files /dev/null and b/results/exp_230901_base_1693564635742_320000_1.jpg differ
diff --git a/results/origan-v0-new-3d-250k-eye-mouth-hm-weight-10k-10k_1685515837755_190000_1.jpg b/results/origan-v0-new-3d-250k-eye-mouth-hm-weight-10k-10k_1685515837755_190000_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3b69d6ecfe8c6cde5eb9b85ded306917d504b958
Binary files /dev/null and b/results/origan-v0-new-3d-250k-eye-mouth-hm-weight-10k-10k_1685515837755_190000_1.jpg differ
diff --git a/results/p1.png b/results/p1.png
new file mode 100644
index 0000000000000000000000000000000000000000..58e38437c0f304d305f6450e2471b47fe5d82284
Binary files /dev/null and b/results/p1.png differ
diff --git a/results/p2.png b/results/p2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b0f80667a6ad8cba027c282b1a54dc525d1cd8f
Binary files /dev/null and b/results/p2.png differ
diff --git a/results/p3.png b/results/p3.png
new file mode 100644
index 0000000000000000000000000000000000000000..738b083d2cfbdcb23423a39947931d5475db66a9
Binary files /dev/null and b/results/p3.png differ
diff --git a/results/p4.png b/results/p4.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0ce4bbd951b221b56b996f1f7d1b9361b4cccfa
Binary files /dev/null and b/results/p4.png differ
diff --git a/results/p5.png b/results/p5.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ecb370a81b498cbdd507bb7b7cb7dc65f1add26
Binary files /dev/null and b/results/p5.png differ
diff --git a/utils/__pycache__/visualizer.cpython-310.pyc b/utils/__pycache__/visualizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df3b594637063b9e93893062c1ebd6fce7a977a0
Binary files /dev/null and b/utils/__pycache__/visualizer.cpython-310.pyc differ
diff --git a/utils/visualizer.py b/utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..37810fd8ffe715e8d975f839b94f43bc683ede80
--- /dev/null
+++ b/utils/visualizer.py
@@ -0,0 +1,39 @@
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+
+class Visualizer:
+ """
+ Tensorboard 可视化监控类
+ """
+
+ def __init__(self, opt):
+ """ """
+ self.opt = opt # cache the option
+ self.writer = SummaryWriter(log_dir=opt.log_dir)
+
+ def display_current_results(self, iters, visuals_dict):
+ """
+ Display current images
+
+ Parameters:
+ ----------
+ visuals (OrderedDict) - - dictionary of images to display
+ iters (int) - - the current iteration
+ """
+ for label, image in visuals_dict.items():
+ if image.shape[0] >= 2:
+ image = image[0:2, :, :, :]
+ self.writer.add_images(str(label), (image * 255.0).to(torch.uint8), global_step=iters, dataformats="NCHW")
+
+ def plot_current_losses(self, iters, loss_dict):
+ """
+ Display losses on tensorboard
+
+ Parameters:
+ iters (int) -- current iteration
+ losses (OrderedDict) -- training losses stored in the format of (name, torch.Tensor) pairs
+ """
+ x = iters
+ for k, v in loss_dict.items():
+ self.writer.add_scalar(f"Loss/{k}", v, x)