diff --git a/FaceLandmarkDetection/Dockerfile b/FaceLandmarkDetection/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..39c5d4a5841ee70227428c47c8c02de4a588e150
--- /dev/null
+++ b/FaceLandmarkDetection/Dockerfile
@@ -0,0 +1,33 @@
+# Based on https://github.com/pytorch/pytorch/blob/master/Dockerfile
+FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cmake \
+ git \
+ curl \
+ vim \
+ ca-certificates \
+ libboost-all-dev \
+ python-qt4 \
+ libjpeg-dev \
+ libpng-dev &&\
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+ chmod +x ~/miniconda.sh && \
+ ~/miniconda.sh -b -p /opt/conda && \
+ rm ~/miniconda.sh
+
+ENV PATH /opt/conda/bin:$PATH
+
+RUN conda config --set always_yes yes --set changeps1 no && conda update -q conda
+RUN conda install pytorch torchvision cuda92 -c pytorch
+
+# Install face-alignment package
+WORKDIR /workspace
+RUN chmod -R a+w /workspace
+RUN git clone https://github.com/1adrianb/face-alignment
+WORKDIR /workspace/face-alignment
+RUN pip install -r requirements.txt
+RUN python setup.py install
diff --git a/FaceLandmarkDetection/LICENSE b/FaceLandmarkDetection/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ed4c6d1af07162de4ff7d0192fdbbf42acd77beb
--- /dev/null
+++ b/FaceLandmarkDetection/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2017, Adrian Bulat
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/FaceLandmarkDetection/README.md b/FaceLandmarkDetection/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..211370c6bc8954d4285ae5575b1a6401e19b9874
--- /dev/null
+++ b/FaceLandmarkDetection/README.md
@@ -0,0 +1,183 @@
+# Face Landmark Detection
+
+Detect facial landmarks from Python using the world's most accurate face alignment network, capable of detecting points in both 2D and 3D coordinates.
+
+Built using [FAN](https://www.adrianbulat.com)'s state-of-the-art deep-learning-based face alignment method.
+
+
+<img src="docs/images/face-alignment-adrian.gif" />
+
+**Note:** The lua version is available [here](https://github.com/1adrianb/2D-and-3D-face-alignment).
+
+For numerical evaluations it is highly recommended to use the lua version, which uses identical models to the ones evaluated in the paper. More models will be added soon.
+
+[License: BSD 3-Clause](https://opensource.org/licenses/BSD-3-Clause) [Build status](https://travis-ci.com/1adrianb/face-alignment) [Anaconda](https://anaconda.org/1adrianb/face_alignment)
+[PyPI](https://pypi.org/project/face-alignment/)
+
+## Features
+
+#### Detect 2D facial landmarks in pictures
+
+<img src="docs/images/2dlandmarks.png" />
+
+```python
+import face_alignment
+from skimage import io
+
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+
+input = io.imread('../test/assets/aflw-test.jpg')
+preds = fa.get_landmarks(input)
+```
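+
+`preds` is a list containing one `68 x 2` NumPy array of `(x, y)` landmark coordinates per detected face (the call returns `None` if no face is found).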
+
+#### Detect 3D facial landmarks in pictures
+
+
+
+
+
+```python
+import face_alignment
+from skimage import io
+
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False)
+
+input = io.imread('../test/assets/aflw-test.jpg')
+preds = fa.get_landmarks(input)
+```
+
+#### Process an entire directory in one go
+
+```python
+import face_alignment
+from skimage import io
+
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+
+preds = fa.get_landmarks_from_directory('../test/assets/')
+```
+
+#### Detect the landmarks using a specific face detector
+
+By default the package will use the SFD face detector. However, users can alternatively use dlib or pre-existing ground-truth bounding boxes.
+
+```python
+import face_alignment
+
+# sfd for SFD, dlib for Dlib and folder for existing bounding boxes.
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector='sfd')
+```
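+
+If the bounding boxes are already known (for example from ground-truth annotations), they can also be passed directly to `get_landmarks` so that no detector has to run. A minimal sketch, where the box coordinates are placeholders:
+
+```python
+import numpy as np
+import face_alignment
+from skimage import io
+
+# 'folder' avoids loading a learned detector, since the boxes are supplied manually
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector='folder')
+
+input = io.imread('../test/assets/aflw-test.jpg')
+# One (x1, y1, x2, y2) box per face; these values are placeholders
+boxes = [np.array([98, 85, 330, 331])]
+preds = fa.get_landmarks(input, detected_faces=boxes)
+```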
+
+#### Running on CPU/GPU
+
+In order to specify the device (GPU or CPU) on which the code will run, one can explicitly pass the device flag:
+
+```python
+import face_alignment
+
+# cuda for CUDA
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu')
+```
+
+Please also see the ``examples`` folder.
+
+## Installation
+
+### Requirements
+
+* Python 3.5+ or Python 2.7 (it may work with other versions too)
+* Linux, Windows or macOS
+* pytorch (>=1.0)
+
+While not required, for optimal performance (especially for the detector) it is **highly** recommended to run the code using a CUDA-enabled GPU.
+
+### Binaries
+
+The easiest way to install it is using either pip or conda:
+
+| **Using pip** | **Using conda** |
+|------------------------------|--------------------------------------------|
+| `pip install face-alignment` | `conda install -c 1adrianb face_alignment` |
+
+Alternatively, below you can find instructions to build it from source.
+
+### From source
+
+Install PyTorch and its dependencies. The instructions below are taken from the [pytorch readme](https://github.com/pytorch/pytorch); for an up-to-date version check the framework's GitHub page.
+
+On Linux
+```bash
+export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
+
+# Install basic dependencies
+conda install numpy pyyaml mkl setuptools cmake gcc cffi
+
+# Add LAPACK support for the GPU
+conda install -c soumith magma-cuda80 # or magma-cuda75 if CUDA 7.5
+```
+
+On OSX
+```bash
+export CMAKE_PREFIX_PATH=[anaconda root directory]
+conda install numpy pyyaml setuptools cmake cffi
+```
+#### Get the PyTorch source
+```bash
+git clone --recursive https://github.com/pytorch/pytorch
+```
+
+#### Install PyTorch
+On Linux
+```bash
+python setup.py install
+```
+
+On OSX
+```bash
+MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
+```
+
+#### Get the Face Alignment source code
+```bash
+git clone https://github.com/1adrianb/face-alignment
+```
+#### Install the Face Alignment lib
+```bash
+pip install -r requirements.txt
+python setup.py install
+```
+
+### Docker image
+
+A Dockerfile is provided to build images with CUDA and cuDNN support (the provided image is based on CUDA 9.2 with cuDNN 7). For more instructions on building and running Docker images, check the official Docker documentation.
+```
+docker build -t face-alignment .
+```
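+
+To use a GPU from inside the container you will also need the NVIDIA container runtime on the host. A typical invocation (an assumption about your setup; adjust the mounted path and command as needed) would look like:
+```
+docker run --runtime=nvidia -it --rm -v $PWD:/workspace/data face-alignment python
+```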
+
+## How does it work?
+
+While the work is presented here as a black box, if you want to know more about the inner workings of the method please check the original paper, either on arXiv or on my [webpage](https://www.adrianbulat.com).
+
+## Contributions
+
+All contributions are welcome. If you encounter any issues (including examples of images where it fails), feel free to open an issue.
+
+## Citation
+
+```
+@inproceedings{bulat2017far,
+ title={How far are we from solving the 2D \& 3D Face Alignment problem? (and a dataset of 230,000 3D facial landmarks)},
+ author={Bulat, Adrian and Tzimiropoulos, Georgios},
+ booktitle={International Conference on Computer Vision},
+ year={2017}
+}
+```
+
+For citing dlib, pytorch or any other packages used here, please check the original pages of their respective authors.
+
+## Acknowledgements
+
+* To the [pytorch](http://pytorch.org/) team for providing such an awesome deep learning framework
+* To [my supervisor](http://www.cs.nott.ac.uk/~pszyt/) for his patience and suggestions.
+* To all the other Python developers who made available the rest of the packages used in this repository.
diff --git a/FaceLandmarkDetection/docs/images/2dlandmarks.png b/FaceLandmarkDetection/docs/images/2dlandmarks.png
new file mode 100644
index 0000000000000000000000000000000000000000..f815722b4b22066e21fda9210400e86d36019f8a
Binary files /dev/null and b/FaceLandmarkDetection/docs/images/2dlandmarks.png differ
diff --git a/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif b/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif
new file mode 100644
index 0000000000000000000000000000000000000000..b6bb4224b922dfa04b7fa61805af65e56a9a08be
Binary files /dev/null and b/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif differ
diff --git a/FaceLandmarkDetection/face_alignment/__init__.py b/FaceLandmarkDetection/face_alignment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bae29fd5f85b41e4669302bd2603bc6924eddc7
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+__author__ = """Adrian Bulat"""
+__email__ = 'adrian.bulat@nottingham.ac.uk'
+__version__ = '1.0.1'
+
+from .api import FaceAlignment, LandmarksType, NetworkSize
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26bc9e86be76e79e7a32a0c66f262f2581e87fb5
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..079f72e8c64729757200ce1dcbb9f62e36db71de
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8b0efd010573c3ec4fd536afe90d65547b0cbdc
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf659733951be5a40c9a60d2fe54a8eb18bfb97f
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3b480edcd636d96fe427e73d5d1006bd1e91b36
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3407599d10993ad2af893304eb58cc4e94456be
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b3f0648a7b85eea4c2d8eedcd5434a40238744b
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b85e82aaa020db2e651ee1f7ea67e71195a6f1c3
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/api.py b/FaceLandmarkDetection/face_alignment/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a09a3095d08ef20c46550662766000d13d507c
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/api.py
@@ -0,0 +1,207 @@
+from __future__ import print_function
+import os
+import torch
+from torch.utils.model_zoo import load_url
+from enum import Enum
+from skimage import io
+from skimage import color
+import numpy as np
+import cv2
+try:
+ import urllib.request as request_file
+except BaseException:
+ import urllib as request_file
+
+from .models import FAN, ResNetDepth
+from .utils import *
+
+
+class LandmarksType(Enum):
+ """Enum class defining the type of landmarks to detect.
+
+ ``_2D`` - the detected points ``(x,y)`` are in a 2D space and follow the visible contour of the face
+ ``_2halfD`` - these points represent the projection of the 3D points onto the 2D image plane
+ ``_3D`` - the points ``(x,y,z)`` are detected in a 3D space
+
+ """
+ _2D = 1
+ _2halfD = 2
+ _3D = 3
+
+
+class NetworkSize(Enum):
+ # TINY = 1
+ # SMALL = 2
+ # MEDIUM = 3
+ LARGE = 4
+
+ def __new__(cls, value):
+ member = object.__new__(cls)
+ member._value_ = value
+ return member
+
+ def __int__(self):
+ return self.value
+
+models_urls = {
+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar',
+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar',
+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar',
+}
+
+
+class FaceAlignment:
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
+ self.device = device
+ self.flip_input = flip_input
+ self.landmarks_type = landmarks_type
+ self.verbose = verbose
+
+ network_size = int(network_size)
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ # Get the face detector
+ face_detector_module = __import__('face_alignment.detection.' + face_detector,
+ globals(), locals(), [face_detector], 0)
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
+
+ # Initialise the face alignment network
+ self.face_alignment_net = FAN(network_size)
+ if landmarks_type == LandmarksType._2D:
+ network_name = '2DFAN-' + str(network_size)
+ else:
+ network_name = '3DFAN-' + str(network_size)
+
+ fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage)
+ self.face_alignment_net.load_state_dict(fan_weights)
+
+ self.face_alignment_net.to(device)
+ self.face_alignment_net.eval()
+
+ # Initialise the depth prediction network
+ if landmarks_type == LandmarksType._3D:
+ self.depth_prediciton_net = ResNetDepth()
+
+ depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage)
+ depth_dict = {
+ k.replace('module.', ''): v for k,
+ v in depth_weights['state_dict'].items()}
+ self.depth_prediciton_net.load_state_dict(depth_dict)
+
+ self.depth_prediciton_net.to(device)
+ self.depth_prediciton_net.eval()
+
+ def get_landmarks(self, image_or_path, detected_faces=None):
+ """Deprecated, please use get_landmarks_from_image
+
+ Arguments:
+ image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
+
+ Keyword Arguments:
+ detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
+ in the image (default: {None})
+ """
+ return self.get_landmarks_from_image(image_or_path, detected_faces)
+
+ def get_landmarks_from_image(self, image_or_path, detected_faces=None):
+ """Predict the landmarks for each face present in the image.
+
+ This function predicts a set of 68 2D or 3D landmarks, one set for each face present in the image.
+ If detected_faces is None the method will also run a face detector.
+
+ Arguments:
+ image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
+
+ Keyword Arguments:
+ detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
+ in the image (default: {None})
+ """
+ if isinstance(image_or_path, str):
+ try:
+ image = io.imread(image_or_path)
+ except IOError:
+ print("error opening file :: ", image_or_path)
+ return None
+ else:
+ image = image_or_path
+
+ if image.ndim == 2:
+ image = color.gray2rgb(image)
+ elif image.ndim == 3 and image.shape[2] == 4:
+ # Drop the alpha channel of RGBA images
+ image = image[..., :3]
+
+ if detected_faces is None:
+ detected_faces = self.face_detector.detect_from_image(image[..., ::-1].copy())
+
+ if len(detected_faces) == 0:
+ print("Warning: No faces were detected.")
+ return None
+
+ torch.set_grad_enabled(False)
+ landmarks = []
+ for i, d in enumerate(detected_faces):
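+ # Crop parameters: the centre of the detection box, shifted up by 12% of
+ # the box height, and a scale derived from the box size relative to the
+ # detector's reference scale.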
+ center = torch.FloatTensor(
+ [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
+ center[1] = center[1] - (d[3] - d[1]) * 0.12
+ scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
+
+ inp = crop(image, center, scale)
+ inp = torch.from_numpy(inp.transpose(
+ (2, 0, 1))).float()
+
+ inp = inp.to(self.device)
+ inp.div_(255.0).unsqueeze_(0)
+
+ out = self.face_alignment_net(inp)[-1].detach()
+ if self.flip_input:
+ out += flip(self.face_alignment_net(flip(inp))
+ [-1].detach(), is_label=True)
+ out = out.cpu()
+
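+ # The heatmaps are predicted at 1/4 of the network input resolution, so
+ # the raw heatmap coordinates are scaled back up by 4 below.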
+ pts, pts_img = get_preds_fromhm(out, center, scale)
+ pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
+
+ if self.landmarks_type == LandmarksType._3D:
+ heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
+ for i in range(68):
+ if pts[i, 0] > 0:
+ heatmaps[i] = draw_gaussian(
+ heatmaps[i], pts[i], 2)
+ heatmaps = torch.from_numpy(
+ heatmaps).unsqueeze_(0)
+
+ heatmaps = heatmaps.to(self.device)
+ depth_pred = self.depth_prediciton_net(
+ torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
+ pts_img = torch.cat(
+ (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
+
+ landmarks.append(pts_img.numpy())
+
+ return landmarks
+
+ def get_landmarks_from_directory(self, path, extensions=['.jpg', '.png'], recursive=True, show_progress_bar=True):
+ detected_faces = self.face_detector.detect_from_directory(path, extensions, recursive, show_progress_bar)
+
+ predictions = {}
+ for image_path, bounding_boxes in detected_faces.items():
+ image = io.imread(image_path)
+ preds = self.get_landmarks_from_image(image, bounding_boxes)
+ predictions[image_path] = preds
+
+ return predictions
+
+ def remove_models(self):
+ base_path = os.path.join(appdata_dir('face_alignment'), "data")
+ for data_model in os.listdir(base_path):
+ file_path = os.path.join(base_path, data_model)
+ try:
+ if os.path.isfile(file_path):
+ print('Removing ' + data_model + ' ...')
+ os.unlink(file_path)
+ except Exception as e:
+ print(e)
diff --git a/FaceLandmarkDetection/face_alignment/detection/__init__.py b/FaceLandmarkDetection/face_alignment/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a6b0402dae864a3cc5dc2a90a412fd842a0efc7
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/__init__.py
@@ -0,0 +1 @@
+from .core import FaceDetector
\ No newline at end of file
diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09ea4818d422fd3267003f907aad6a7b79fd829e
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..605c422babf01756a23b39a4b869b24896a17586
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93d2a9d3998cf9c2facbbf5b6731d067c7193967
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95657db16b78133d1e8f86160cb96fe371b5ba16
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/core.py b/FaceLandmarkDetection/face_alignment/detection/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..47a3f9d5fcf461843b2aafb75031f06be591c7dd
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/core.py
@@ -0,0 +1,131 @@
+import logging
+import glob
+from tqdm import tqdm
+import numpy as np
+import torch
+import cv2
+from skimage import io
+
+
+class FaceDetector(object):
+ """An abstract class representing a face detector.
+
+ Any other face detection implementation must subclass it. All subclasses
+ must implement ``detect_from_image``, which returns a list of detected
+ bounding boxes. Optionally, for speed considerations, implementing detection
+ directly from a path is recommended.
+ """
+
+ def __init__(self, device, verbose):
+ self.device = device
+ self.verbose = verbose
+
+ if verbose:
+ logger = logging.getLogger(__name__)
+ if 'cpu' in device:
+ logger.warning("Detection running on CPU, this may be potentially slow.")
+
+ if 'cpu' not in device and 'cuda' not in device:
+ if verbose:
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
+ raise ValueError
+
+ def detect_from_image(self, tensor_or_path):
+ """Detects faces in a given image.
+
+ This function detects the faces present in a provided (usually BGR)
+ image. The input can be either the image itself or the path to it.
+
+ Arguments:
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
+ to an image or the image itself.
+
+ Example::
+
+ >>> path_to_image = 'data/image_01.jpg'
+ ... detected_faces = detect_from_image(path_to_image)
+ [A list of bounding boxes (x1, y1, x2, y2)]
+ >>> image = cv2.imread(path_to_image)
+ ... detected_faces = detect_from_image(image)
+ [A list of bounding boxes (x1, y1, x2, y2)]
+
+ """
+ raise NotImplementedError
+
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
+ """Detects faces from all the images present in a given directory.
+
+ Arguments:
+ path {string} -- a string containing a path that points to the folder containing the images
+
+ Keyword Arguments:
+ extensions {list} -- list of strings containing the extensions to be
+ considered, in the format ``.extension_name`` (default: {['.jpg', '.png']})
+ recursive {bool} -- whether to scan the folder recursively (default: {False})
+ show_progress_bar {bool} -- whether to display a progress bar (default: {True})
+
+ Example:
+ >>> directory = 'data'
+ ... detected_faces = detect_from_directory(directory)
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
+
+ """
+ if self.verbose:
+ logger = logging.getLogger(__name__)
+
+ if len(extensions) == 0:
+ if self.verbose:
+ logger.error("Expected at least one extension, but none was received.")
+ raise ValueError
+
+ if self.verbose:
+ logger.info("Constructing the list of images.")
+ additional_pattern = '/**/*' if recursive else '/*'
+ files = []
+ for extension in extensions:
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
+
+ if self.verbose:
+ logger.info("Finished searching for images. %s images found", len(files))
+ logger.info("Preparing to run the detection.")
+
+ predictions = {}
+ for image_path in tqdm(files, disable=not show_progress_bar):
+ if self.verbose:
+ logger.info("Running the face detector on image: %s", image_path)
+ predictions[image_path] = self.detect_from_image(image_path)
+
+ if self.verbose:
+ logger.info("The detector was successfully run on all %s images", len(files))
+
+ return predictions
+
+ @property
+ def reference_scale(self):
+ raise NotImplementedError
+
+ @property
+ def reference_x_shift(self):
+ raise NotImplementedError
+
+ @property
+ def reference_y_shift(self):
+ raise NotImplementedError
+
+ @staticmethod
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
+
+ Arguments:
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
+ """
+ if isinstance(tensor_or_path, str):
+ return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path)
+ elif torch.is_tensor(tensor_or_path):
+ # Call cpu in case its coming from cuda
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
+ elif isinstance(tensor_or_path, np.ndarray):
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
+ else:
+ raise TypeError
diff --git a/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py b/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e5ee5fd5b6dd145f3e3ec65c02d2d8befaef59
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py
@@ -0,0 +1 @@
+from .dlib_detector import DlibDetector as FaceDetector
\ No newline at end of file
diff --git a/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py b/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc8368aef0203753fad4637b4e3981281977fe6
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py
@@ -0,0 +1,68 @@
+import os
+import cv2
+import dlib
+
+try:
+ import urllib.request as request_file
+except BaseException:
+ import urllib as request_file
+
+from ..core import FaceDetector
+from ...utils import appdata_dir
+
+
+class DlibDetector(FaceDetector):
+ def __init__(self, device, path_to_detector=None, verbose=False):
+ super().__init__(device, verbose)
+
+ print('Warning: this detector is deprecated. Please use a different one, e.g. S3FD.')
+ base_path = os.path.join(appdata_dir('face_alignment'), "data")
+
+ # Initialise the face detector
+ if 'cuda' in device:
+ if path_to_detector is None:
+ path_to_detector = os.path.join(
+ base_path, "mmod_human_face_detector.dat")
+
+ if not os.path.isfile(path_to_detector):
+ print("Downloading the face detection CNN. Please wait...")
+
+ path_to_temp_detector = os.path.join(
+ base_path, "mmod_human_face_detector.dat.download")
+
+ if os.path.isfile(path_to_temp_detector):
+ os.remove(path_to_temp_detector)
+
+ request_file.urlretrieve(
+ "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat",
+ path_to_temp_detector)
+
+ os.rename(path_to_temp_detector, path_to_detector)
+
+ self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector)
+ else:
+ self.face_detector = dlib.get_frontal_face_detector()
+
+ def detect_from_image(self, tensor_or_path):
+ image = self.tensor_or_path_to_ndarray(tensor_or_path, rgb=False)
+
+ detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
+
+ if 'cuda' not in self.device:
+ detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces]
+ else:
+ detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces]
+
+ return detected_faces
+
+ @property
+ def reference_scale(self):
+ return 195
+
+ @property
+ def reference_x_shift(self):
+ return 0
+
+ @property
+ def reference_y_shift(self):
+ return 0
diff --git a/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py b/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9128ee0e13eb6f0058dbd480046aee50be336f
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py
@@ -0,0 +1 @@
+from .folder_detector import FolderDetector as FaceDetector
\ No newline at end of file
diff --git a/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py b/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..add19fa4e4b933619b75d1dce441e3071b235191
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py
@@ -0,0 +1,53 @@
+import os
+import numpy as np
+import torch
+
+from ..core import FaceDetector
+
+
+class FolderDetector(FaceDetector):
+ '''This is a simple helper module that assumes the faces were detected already
+ (either previously or are provided as ground truth).
+
+ The class expects to find the bounding boxes in the same format used by
+ the rest of the face detectors, namely ``list[(x1,y1,x2,y2),...]``.
+ For each image the detector will search for a file with the same name and with one of the
+ following extensions: .npy, .t7 or .pth
+
+ '''
+
+ def __init__(self, device, path_to_detector=None, verbose=False):
+ super(FolderDetector, self).__init__(device, verbose)
+
+ def detect_from_image(self, tensor_or_path):
+ # Only strings supported
+ if not isinstance(tensor_or_path, str):
+ raise ValueError
+
+ base_name = os.path.splitext(tensor_or_path)[0]
+
+ if os.path.isfile(base_name + '.npy'):
+ # np.load returns an ndarray, so convert it to the expected list format
+ detected_faces = np.load(base_name + '.npy').tolist()
+ elif os.path.isfile(base_name + '.t7'):
+ detected_faces = torch.load(base_name + '.t7')
+ elif os.path.isfile(base_name + '.pth'):
+ detected_faces = torch.load(base_name + '.pth')
+ else:
+ raise FileNotFoundError
+
+ if not isinstance(detected_faces, list):
+ raise TypeError
+
+ return detected_faces
+
+ @property
+ def reference_scale(self):
+ return 195
+
+ @property
+ def reference_x_shift(self):
+ return 0
+
+ @property
+ def reference_y_shift(self):
+ return 0
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py b/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a63ecd45658f22e66c171ada751fb33764d4559
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py
@@ -0,0 +1 @@
+from .sfd_detector import SFDDetector as FaceDetector
\ No newline at end of file
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..223f1c8c50f26c44013f12edbb297b3f0822d817
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71478176e62df158a6458d825eedf176cff1cb16
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06a77197d502544882d7676b71e89525fab5eaab
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..315d9b9b15d5b094b864f7a4a40c863c5276af0e
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37daedbfb169a7bf35c91a6d9abbf44d22c4c315
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5954c1300f601d431745d74ef3cfded16b6980d
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0984c1e8a9c2bb8b78fc70c10f6fbc1f6cafc1aa
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bff6f907dbf2c33fdaa0062748bb28a76e8b65ad
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1b5f8b065f4da5d65f432a774e5b08620da4723
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33f1ccaeb437d82a166b0fb1603628525796c02f
Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc differ
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py b/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca662365044608b992e7f15794bd2518b0781d7e
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py
@@ -0,0 +1,109 @@
+from __future__ import print_function
+import os
+import sys
+import cv2
+import random
+import datetime
+import time
+import math
+import argparse
+import numpy as np
+import torch
+
+try:
+ from iou import IOU
+except BaseException:
+ # IOU cython speedup 10x
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
+ sb = abs((bx2 - bx1) * (by2 - by1))
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
+ w = x2 - x1
+ h = y2 - y1
+ if w < 0 or h < 0:
+ return 0.0
+ else:
+ return 1.0 * w * h / (sa + sb - w * h)
+
+
+def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
+ return dx, dy, dw, dh
+
+
+def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
+ xc, yc = dx * aww + axc, dy * ahh + ayc
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
+ return x1, y1, x2, y2
+
+
+def nms(dets, thresh):
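+ # Greedy non-maximum suppression: repeatedly keep the highest-scoring box
+ # and discard the remaining boxes whose IoU with it exceeds `thresh`.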
+ if 0 == len(dets):
+ return []
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
+
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py b/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..84d120de15145a830938a2edcf9cf8d29bc59f72
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn.functional as F
+
+import os
+import sys
+import cv2
+import random
+import datetime
+import math
+import argparse
+import numpy as np
+
+import scipy.io as sio
+import zipfile
+from .net_s3fd import s3fd
+from .bbox import *
+
+
+def detect(net, img, device):
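+ # Subtract the per-channel (BGR) image means used by the VGG-based S3FD
+ # network, then reorder to CHW and add a leading batch dimension.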
+ img = img - np.array([104, 117, 123])
+ img = img.transpose(2, 0, 1)
+ img = img.reshape((1,) + img.shape)
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ img = torch.from_numpy(img).float().to(device)
+ BB, CC, HH, WW = img.size()
+ with torch.no_grad():
+ olist = net(img)
+
+ bboxlist = []
+ for i in range(len(olist) // 2):
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
+ olist = [oelem.data.cpu() for oelem in olist]
+ for i in range(len(olist) // 2):
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
+ FB, FC, FH, FW = ocls.size() # feature map size
+ stride = 2**(i + 2) # 4,8,16,32,64,128
+ anchor = stride * 4
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
+ for Iindex, hindex, windex in poss:
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
+ score = ocls[0, 1, hindex, windex]
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
+ variances = [0.1, 0.2]
+ box = decode(loc, priors, variances)
+ x1, y1, x2, y2 = box[0] * 1.0
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
+ bboxlist.append([x1, y1, x2, y2, score])
+ bboxlist = np.array(bboxlist)
+ if 0 == len(bboxlist):
+ bboxlist = np.zeros((1, 5))
+
+ return bboxlist
+
+
+def flip_detect(net, img, device):
+ img = cv2.flip(img, 1)
+ b = detect(net, img, device)
+
+ bboxlist = np.zeros(b.shape)
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
+ bboxlist[:, 1] = b[:, 1]
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
+ bboxlist[:, 3] = b[:, 3]
+ bboxlist[:, 4] = b[:, 4]
+ return bboxlist
+
+
+def pts_to_bb(pts):
+ min_x, min_y = np.min(pts, axis=0)
+ max_x, max_y = np.max(pts, axis=0)
+ return np.array([min_x, min_y, max_x, max_y])
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py b/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc64313c277ab594d0257585c70f147606693452
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L2Norm(nn.Module):
+ def __init__(self, n_channels, scale=1.0):
+ super(L2Norm, self).__init__()
+ self.n_channels = n_channels
+ self.scale = scale
+ self.eps = 1e-10
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
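+ # Initialise the learnable per-channel scale to the constant `scale`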
+ self.weight.data *= 0.0
+ self.weight.data += self.scale
+
+ def forward(self, x):
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
+ x = x / norm * self.weight.view(1, -1, 1, 1)
+ return x
+
+
+class s3fd(nn.Module):
+ def __init__(self):
+ super(s3fd, self).__init__()
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
+
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
+
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.conv3_3_norm = L2Norm(256, scale=10)
+ self.conv4_3_norm = L2Norm(512, scale=8)
+ self.conv5_3_norm = L2Norm(512, scale=5)
+
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ h = F.relu(self.conv1_1(x))
+ h = F.relu(self.conv1_2(h))
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv2_1(h))
+ h = F.relu(self.conv2_2(h))
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv3_1(h))
+ h = F.relu(self.conv3_2(h))
+ h = F.relu(self.conv3_3(h))
+ f3_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv4_1(h))
+ h = F.relu(self.conv4_2(h))
+ h = F.relu(self.conv4_3(h))
+ f4_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv5_1(h))
+ h = F.relu(self.conv5_2(h))
+ h = F.relu(self.conv5_3(h))
+ f5_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.fc6(h))
+ h = F.relu(self.fc7(h))
+ ffc7 = h
+ h = F.relu(self.conv6_1(h))
+ h = F.relu(self.conv6_2(h))
+ f6_2 = h
+ h = F.relu(self.conv7_1(h))
+ h = F.relu(self.conv7_2(h))
+ f7_2 = h
+
+ f3_3 = self.conv3_3_norm(f3_3)
+ f4_3 = self.conv4_3_norm(f4_3)
+ f5_3 = self.conv5_3_norm(f5_3)
+
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
+ cls4 = self.fc7_mbox_conf(ffc7)
+ reg4 = self.fc7_mbox_loc(ffc7)
+ cls5 = self.conv6_2_mbox_conf(f6_2)
+ reg5 = self.conv6_2_mbox_loc(f6_2)
+ cls6 = self.conv7_2_mbox_conf(f7_2)
+ reg6 = self.conv7_2_mbox_loc(f7_2)
+
+ # max-out background label
+ chunk = torch.chunk(cls1, 4, 1)
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
+
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py b/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c7768558b83dceea8f94e4d859b71dfa54cc85
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py
@@ -0,0 +1,51 @@
+import os
+import cv2
+from torch.utils.model_zoo import load_url
+
+from ..core import FaceDetector
+
+from .net_s3fd import s3fd
+from .bbox import *
+from .detect import *
+
+models_urls = {
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
+}
+
+
+class SFDDetector(FaceDetector):
+ def __init__(self, device, path_to_detector=None, verbose=False):
+ super(SFDDetector, self).__init__(device, verbose)
+
+ # Initialise the face detector
+ if path_to_detector is None:
+ model_weights = load_url(models_urls['s3fd'])
+ else:
+ model_weights = torch.load(path_to_detector)
+
+ self.face_detector = s3fd()
+ self.face_detector.load_state_dict(model_weights)
+ self.face_detector.to(device)
+ self.face_detector.eval()
+
+ def detect_from_image(self, tensor_or_path):
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
+
+ bboxlist = detect(self.face_detector, image, device=self.device)
+ keep = nms(bboxlist, 0.3)
+ bboxlist = bboxlist[keep, :]
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
+
+ return bboxlist
+
+ @property
+ def reference_scale(self):
+ return 195
+
+ @property
+ def reference_x_shift(self):
+ return 0
+
+ @property
+ def reference_y_shift(self):
+ return 0
diff --git a/FaceLandmarkDetection/face_alignment/models.py b/FaceLandmarkDetection/face_alignment/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2dde32bdf72c25a4600e48efa73ffc0d4a3893
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/models.py
@@ -0,0 +1,261 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+ stride=strd, padding=padding, bias=bias)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_planes, out_planes):
+ super(ConvBlock, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
+
+ if in_planes != out_planes:
+ self.downsample = nn.Sequential(
+ nn.BatchNorm2d(in_planes),
+ nn.ReLU(True),
+ nn.Conv2d(in_planes, out_planes,
+ kernel_size=1, stride=1, bias=False),
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+
+ out1 = self.bn1(x)
+ out1 = F.relu(out1, True)
+ out1 = self.conv1(out1)
+
+ out2 = self.bn2(out1)
+ out2 = F.relu(out2, True)
+ out2 = self.conv2(out2)
+
+ out3 = self.bn3(out2)
+ out3 = F.relu(out3, True)
+ out3 = self.conv3(out3)
+
+ out3 = torch.cat((out1, out2, out3), 1)
+
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+
+ out3 += residual
+
+ return out3
+
+
+class Bottleneck(nn.Module):
+
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HourGlass(nn.Module):
+ def __init__(self, num_modules, depth, num_features):
+ super(HourGlass, self).__init__()
+ self.num_modules = num_modules
+ self.depth = depth
+ self.features = num_features
+
+ self._generate_network(self.depth)
+
+ def _generate_network(self, level):
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
+
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
+
+ if level > 1:
+ self._generate_network(level - 1)
+ else:
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
+
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
+
+ def _forward(self, level, inp):
+ # Upper branch
+ up1 = inp
+ up1 = self._modules['b1_' + str(level)](up1)
+
+ # Lower branch
+ low1 = F.avg_pool2d(inp, 2, stride=2)
+ low1 = self._modules['b2_' + str(level)](low1)
+
+ if level > 1:
+ low2 = self._forward(level - 1, low1)
+ else:
+ low2 = low1
+ low2 = self._modules['b2_plus_' + str(level)](low2)
+
+ low3 = low2
+ low3 = self._modules['b3_' + str(level)](low3)
+
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
+
+ return up1 + up2
+
+ def forward(self, x):
+ return self._forward(self.depth, x)
+
+
+class FAN(nn.Module):
+
+ def __init__(self, num_modules=1):
+ super(FAN, self).__init__()
+ self.num_modules = num_modules
+
+ # Base part
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.conv2 = ConvBlock(64, 128)
+ self.conv3 = ConvBlock(128, 128)
+ self.conv4 = ConvBlock(128, 256)
+
+ # Stacking part
+ for hg_module in range(self.num_modules):
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
+ self.add_module('conv_last' + str(hg_module),
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
+ 68, kernel_size=1, stride=1, padding=0))
+
+ if hg_module < self.num_modules - 1:
+ self.add_module(
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
+ 256, kernel_size=1, stride=1, padding=0))
+
+ def forward(self, x):
+ x = F.relu(self.bn1(self.conv1(x)), True)
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ previous = x
+
+ outputs = []
+ for i in range(self.num_modules):
+ hg = self._modules['m' + str(i)](previous)
+
+ ll = hg
+ ll = self._modules['top_m_' + str(i)](ll)
+
+ ll = F.relu(self._modules['bn_end' + str(i)]
+ (self._modules['conv_last' + str(i)](ll)), True)
+
+ # Predict heatmaps
+ tmp_out = self._modules['l' + str(i)](ll)
+ outputs.append(tmp_out)
+
+ if i < self.num_modules - 1:
+ ll = self._modules['bl' + str(i)](ll)
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
+ previous = previous + ll + tmp_out_
+
+ return outputs
+
+
+class ResNetDepth(nn.Module):
+
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
+ self.inplanes = 64
+ super(ResNetDepth, self).__init__()
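+ # The input is the RGB face crop concatenated with the 68 landmark
+ # heatmaps (3 + 68 channels); the layer configuration matches ResNet-152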
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
diff --git a/FaceLandmarkDetection/face_alignment/utils.py b/FaceLandmarkDetection/face_alignment/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..619570d580d65c84ba60ee043ad381a850f1ab26
--- /dev/null
+++ b/FaceLandmarkDetection/face_alignment/utils.py
@@ -0,0 +1,274 @@
+from __future__ import print_function
+import os
+import sys
+import time
+import torch
+import math
+import numpy as np
+import cv2
+
+
+def _gaussian(
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
+ mean_vert=0.5):
+ # handle some defaults
+ if width is None:
+ width = size
+ if height is None:
+ height = size
+ if sigma_horz is None:
+ sigma_horz = sigma
+ if sigma_vert is None:
+ sigma_vert = sigma
+ center_x = mean_horz * width + 0.5
+ center_y = mean_vert * height + 0.5
+ gauss = np.empty((height, width), dtype=np.float32)
+ # generate kernel
+ for i in range(height):
+ for j in range(width):
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
+ if normalize:
+ gauss = gauss / np.sum(gauss)
+ return gauss
+
+
+def draw_gaussian(image, point, sigma):
+ # Check if the gaussian is inside
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
+ return image
+ size = 6 * sigma + 1
+ g = _gaussian(size)
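+    # Clip the kernel and image ranges to their overlapping region; coordinates
+    # below are 1-indexed, hence the max(1, ...) guards and the -1 when slicing.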
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
+ assert (g_x[0] > 0 and g_y[1] > 0)
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
+ image[image > 1] = 1
+ return image
+
+
+def transform(point, center, scale, resolution, invert=False):
+ """Generate and affine transformation matrix.
+
+    Given a set of points, a center, a scale and a target resolution, the
+    function generates an affine transformation matrix. If invert is ``True``
+ it will produce the inverse transformation.
+
+ Arguments:
+ point {torch.tensor} -- the input 2D point
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
+ scale {float} -- the scale of the face/object
+ resolution {float} -- the output resolution
+
+ Keyword Arguments:
+        invert {bool} -- whether the function should produce the direct or the
+        inverse transformation matrix (default: {False})
+ """
+ _pt = torch.ones(3)
+ _pt[0] = point[0]
+ _pt[1] = point[1]
+
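+    # The reference box spans 200 * scale pixels around the center; the matrix
+    # maps that box onto a resolution x resolution output frame.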
+ h = 200.0 * scale
+ t = torch.eye(3)
+ t[0, 0] = resolution / h
+ t[1, 1] = resolution / h
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
+
+ if invert:
+ t = torch.inverse(t)
+
+ new_point = (torch.matmul(t, _pt))[0:2]
+
+ return new_point.int()
+
+
+def crop(image, center, scale, resolution=256.0):
+ """Center crops an image or set of heatmaps
+
+ Arguments:
+ image {numpy.array} -- an rgb image
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
+ scale {float} -- scale of the face
+
+ Keyword Arguments:
+ resolution {float} -- the size of the output cropped image (default: {256.0})
+
+    Returns:
+        numpy.array -- the cropped image, resized to ``resolution`` x ``resolution``
+    """
+    # Crop around the center point. Input is expected to be an np.ndarray.
+ ul = transform([1, 1], center, scale, resolution, True)
+ br = transform([resolution, resolution], center, scale, resolution, True)
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
+ if image.ndim > 2:
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
+ image.shape[2]], dtype=np.int32)
+ newImg = np.zeros(newDim, dtype=np.uint8)
+ else:
+        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
+ newImg = np.zeros(newDim, dtype=np.uint8)
+ ht = image.shape[0]
+ wd = image.shape[1]
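+    # newX/newY are the destination ranges inside the zero-padded crop and
+    # oldX/oldY the corresponding source ranges in the input image.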
+ newX = np.array(
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
+ newY = np.array(
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
+ interpolation=cv2.INTER_LINEAR)
+ return newImg
+
+
+def get_preds_fromhm(hm, center=None, scale=None):
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
+ and the scale is provided the function will return the points also in
+ the original coordinate frame.
+
+ Arguments:
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+ Keyword Arguments:
+ center {torch.tensor} -- the center of the bounding box (default: {None})
+ scale {float} -- face scale (default: {None})
+ """
+    _, idx = torch.max(
+        hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+ idx += 1
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
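+    # Refine each peak by a quarter of a pixel towards the larger neighbouring
+    # heatmap value (simple sub-pixel refinement).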
+ for i in range(preds.size(0)):
+ for j in range(preds.size(1)):
+ hm_ = hm[i, j, :]
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+ diff = torch.FloatTensor(
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+ preds[i, j].add_(diff.sign_().mul_(.25))
+
+ preds.add_(-.5)
+
+ preds_orig = torch.zeros(preds.size())
+ if center is not None and scale is not None:
+ for i in range(hm.size(0)):
+ for j in range(hm.size(1)):
+ preds_orig[i, j] = transform(
+ preds[i, j], center, scale, hm.size(2), True)
+
+ return preds, preds_orig
+
+
+def shuffle_lr(parts, pairs=None):
+ """Shuffle the points left-right according to the axis of symmetry
+ of the object.
+
+ Arguments:
+ parts {torch.tensor} -- a 3D or 4D object containing the
+ heatmaps.
+
+ Keyword Arguments:
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
+ """
+ if pairs is None:
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
+ 62, 61, 60, 67, 66, 65]
+ if parts.ndimension() == 3:
+ parts = parts[pairs, ...]
+ else:
+ parts = parts[:, pairs, ...]
+
+ return parts
+
+
+def flip(tensor, is_label=False):
+ """Flip an image or a set of heatmaps left-right
+
+ Arguments:
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
+
+ Keyword Arguments:
+        is_label {bool} -- whether the input is a set of heatmaps rather than an image (default: {False})
+ """
+ if not torch.is_tensor(tensor):
+ tensor = torch.from_numpy(tensor)
+
+ if is_label:
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
+ else:
+ tensor = tensor.flip(tensor.ndimension() - 1)
+
+ return tensor
+
+# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
+
+
+def appdata_dir(appname=None, roaming=False):
+ """ appdata_dir(appname=None, roaming=False)
+
+ Get the path to the application directory, where applications are allowed
+ to write user specific files (e.g. configurations). For non-user specific
+ data, consider using common_appdata_dir().
+ If appname is given, a subdir is appended (and created if necessary).
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
+ """
+
+ # Define default user directory
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
+ if userDir is None:
+ userDir = os.path.expanduser('~')
+ if not os.path.isdir(userDir): # pragma: no cover
+ userDir = '/var/tmp' # issue #54
+
+ # Get system app data dir
+ path = None
+ if sys.platform.startswith('win'):
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
+ path = (path2 or path1) if roaming else (path1 or path2)
+ elif sys.platform.startswith('darwin'):
+ path = os.path.join(userDir, 'Library', 'Application Support')
+ # On Linux and as fallback
+ if not (path and os.path.isdir(path)):
+ path = userDir
+
+ # Maybe we should store things local to the executable (in case of a
+ # portable distro or a frozen application that wants to be portable)
+ prefix = sys.prefix
+ if getattr(sys, 'frozen', None):
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
+ for reldir in ('settings', '../settings'):
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
+ if os.path.isdir(localpath): # pragma: no cover
+ try:
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
+ os.remove(os.path.join(localpath, 'test.write'))
+ except IOError:
+ pass # We cannot write in this directory
+ else:
+ path = localpath
+ break
+
+ # Get path specific for this app
+ if appname:
+ if path == userDir:
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
+ path = os.path.join(path, appname)
+ if not os.path.isdir(path): # pragma: no cover
+ os.mkdir(path)
+
+ # Done
+ return path
diff --git a/FaceLandmarkDetection/get_face_landmark.py b/FaceLandmarkDetection/get_face_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee59679395824add92f2539a5a2bec3443a0ec9
--- /dev/null
+++ b/FaceLandmarkDetection/get_face_landmark.py
@@ -0,0 +1,46 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+import torch
+import pickle
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+
+import cv2
+import os
+import face_alignment
+from skimage import io, transform
+
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cuda:0', flip_input=False)
+
+
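+# Detect 68 2D landmarks for every image under FilePath and save them as
+# per-image .txt files (one "x y" pair per line) for the later restoration steps.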
+Nums = 0
+FilePath = '../TestData/RealVgg/Imgs'
+SavePath = '../TestData/RealVgg/Imgs/Landmarks'
+if not os.path.exists(SavePath):
+ os.mkdir(SavePath)
+
+ImgNames = os.listdir(FilePath)
+ImgNames.sort()
+
+for i,name in enumerate(ImgNames):
+ print((i,name))
+
+ imgs = io.imread(os.path.join(FilePath,name))
+
+ imgO = imgs
+ try:
+ PredsAll = fa.get_landmarks(imgO)
+    except Exception:
+ print('#########No face')
+ continue
+ if PredsAll is None:
+ print('#########No face2')
+ continue
+    if len(PredsAll) != 1:
+        print('#########too many faces')
+ continue
+ preds = PredsAll[-1]
+ AddLength = np.sqrt(np.sum(np.power(preds[27][0:2]-preds[33][0:2],2)))
+ SaveName = name+'.txt'
+
+ np.savetxt(os.path.join(SavePath,SaveName),preds[:,0:2],fmt='%.3f')
diff --git a/FaceLandmarkDetection/setup.cfg b/FaceLandmarkDetection/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..1d5d127316734bcda878fffbbc24c56404911739
--- /dev/null
+++ b/FaceLandmarkDetection/setup.cfg
@@ -0,0 +1,32 @@
+[bumpversion]
+current_version = 1.0.1
+commit = True
+tag = True
+
+[bumpversion:file:setup.py]
+search = version='{current_version}'
+replace = version='{new_version}'
+
+[bumpversion:file:face_alignment/__init__.py]
+search = __version__ = '{current_version}'
+replace = __version__ = '{new_version}'
+
+[metadata]
+description-file = README.md
+
+[bdist_wheel]
+universal = 1
+
+[flake8]
+exclude =
+ .github,
+ examples,
+ docs,
+ .tox,
+ bin,
+ dist,
+ tools,
+ *.egg-info,
+ __init__.py,
+ *.yml
+max-line-length = 160
\ No newline at end of file
diff --git a/FaceLandmarkDetection/setup.py b/FaceLandmarkDetection/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa1a0be9bbdbc3d8e8a07e07ab11b88b5b865a30
--- /dev/null
+++ b/FaceLandmarkDetection/setup.py
@@ -0,0 +1,83 @@
+import io
+import os
+from os import path
+import re
+from setuptools import setup, find_packages
+# To use consistent encodings
+from codecs import open
+
+# Function from: https://github.com/pytorch/vision/blob/master/setup.py
+
+
+def read(*names, **kwargs):
+ with io.open(
+ os.path.join(os.path.dirname(__file__), *names),
+ encoding=kwargs.get("encoding", "utf8")
+ ) as fp:
+ return fp.read()
+
+# Function from: https://github.com/pytorch/vision/blob/master/setup.py
+
+
+def find_version(*file_paths):
+ version_file = read(*file_paths)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+ version_file, re.M)
+ if version_match:
+ return version_match.group(1)
+ raise RuntimeError("Unable to find version string.")
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file:
+ long_description = readme_file.read()
+
+VERSION = find_version('face_alignment', '__init__.py')
+
+requirements = [
+ 'torch',
+ 'numpy',
+ 'scipy>=0.17',
+ 'scikit-image',
+ 'opencv-python',
+ 'tqdm',
+ 'enum34;python_version<"3.4"'
+]
+
+setup(
+ name='face_alignment',
+ version=VERSION,
+
+ description="Detector 2D or 3D face landmarks from Python",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+
+ # Author details
+ author="Adrian Bulat",
+ author_email="adrian.bulat@nottingham.ac.uk",
+ url="https://github.com/1adrianb/face-alignment",
+
+ # Package info
+ packages=find_packages(exclude=('test',)),
+
+ install_requires=requirements,
+ license='BSD',
+ zip_safe=True,
+
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Operating System :: OS Independent',
+ 'License :: OSI Approved :: BSD License',
+ 'Natural Language :: English',
+
+ # Supported python versions
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ ],
+)
diff --git a/FaceLandmarkDetection/test/assets/aflw-test.jpg b/FaceLandmarkDetection/test/assets/aflw-test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e798f64cdeda860e3b2a99e0d7e5ce8e52d7c0f4
Binary files /dev/null and b/FaceLandmarkDetection/test/assets/aflw-test.jpg differ
diff --git a/FaceLandmarkDetection/test/facealignment_test.py b/FaceLandmarkDetection/test/facealignment_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..85637f953ab580fd8756f3476afa02723a8e12c0
--- /dev/null
+++ b/FaceLandmarkDetection/test/facealignment_test.py
@@ -0,0 +1,11 @@
+import unittest
+import face_alignment
+
+
+class Tester(unittest.TestCase):
+ def test_predict_points(self):
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu')
+ fa.get_landmarks('test/assets/aflw-test.jpg')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/FaceLandmarkDetection/test/smoke_test.py b/FaceLandmarkDetection/test/smoke_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a006d0793ac3aa1cdaedf57f45eac7b9672464
--- /dev/null
+++ b/FaceLandmarkDetection/test/smoke_test.py
@@ -0,0 +1,2 @@
+import torch
+import face_alignment
diff --git a/FaceLandmarkDetection/test/test_utils.py b/FaceLandmarkDetection/test/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c4d3eb0d7496d0900f3e32f14bfa8172a4252c
--- /dev/null
+++ b/FaceLandmarkDetection/test/test_utils.py
@@ -0,0 +1,36 @@
+import unittest
+from face_alignment.utils import *
+import numpy as np
+import torch
+
+
+class Tester(unittest.TestCase):
+ def test_flip_is_label(self):
+ # Generate the points
+ heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32'))
+
+ flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True)
+
+ assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy())
+
+ def test_flip_is_image(self):
+        fake_image = torch.rand(3, 256, 256)
+        flipped_fake_image = flip(flip(fake_image.clone()))
+
+        assert np.allclose(fake_image.numpy(), flipped_fake_image.numpy())
+
+ def test_getpreds(self):
+ pts = torch.from_numpy(np.random.randint(1, high=63, size=(68, 2)).astype('float32'))
+
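+        # Render a gaussian at every random point, then check that the peaks
+        # recovered by get_preds_fromhm lie within 5 pixels of the inputs.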
+ heatmaps = np.zeros((68, 256, 256))
+ for i in range(68):
+ if pts[i, 0] > 0:
+ heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2)
+ heatmaps = torch.from_numpy(np.expand_dims(heatmaps, axis=0))
+
+ preds, _ = get_preds_fromhm(heatmaps)
+
+ assert np.allclose(pts.numpy(), preds.numpy(), atol=5)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/FaceLandmarkDetection/tox.ini b/FaceLandmarkDetection/tox.ini
new file mode 100644
index 0000000000000000000000000000000000000000..0a862cd793d75c6005fb4dcda05edc58044f86b3
--- /dev/null
+++ b/FaceLandmarkDetection/tox.ini
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 120
+ignore = E305,E402,E721,F401,F403,F405,F821,F841,F999
\ No newline at end of file
diff --git a/Imgs/RealLR/n000056_0060_01.png b/Imgs/RealLR/n000056_0060_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..005309a61a613bee4270ae4d05fb7320b1bde0be
Binary files /dev/null and b/Imgs/RealLR/n000056_0060_01.png differ
diff --git a/Imgs/RealLR/n000067_0228_01.png b/Imgs/RealLR/n000067_0228_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..c0e4f322ec3f5db1554777e233237c5cdd0fb007
Binary files /dev/null and b/Imgs/RealLR/n000067_0228_01.png differ
diff --git a/Imgs/RealLR/n000184_0094_01.png b/Imgs/RealLR/n000184_0094_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e2995de5706ed71a445d3519d2c214df503b9ae
Binary files /dev/null and b/Imgs/RealLR/n000184_0094_01.png differ
diff --git a/Imgs/RealLR/n000241_0132_04.png b/Imgs/RealLR/n000241_0132_04.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f058089b0320d201f35736148e7561ad5a625d1
Binary files /dev/null and b/Imgs/RealLR/n000241_0132_04.png differ
diff --git a/Imgs/RealLR/n000262_0097_01.png b/Imgs/RealLR/n000262_0097_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9aedb26d7d4c885385bf6136cf170a2c8a4f4d9
Binary files /dev/null and b/Imgs/RealLR/n000262_0097_01.png differ
diff --git a/Imgs/ShowResults/n000056_0060_01.png b/Imgs/ShowResults/n000056_0060_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..cda40a58b3291d7c88a4930c05085604fb31f3fd
Binary files /dev/null and b/Imgs/ShowResults/n000056_0060_01.png differ
diff --git a/Imgs/ShowResults/n000067_0228_01.png b/Imgs/ShowResults/n000067_0228_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a4cf68d2a0d18fe8e195769d333d46d4a778f40
Binary files /dev/null and b/Imgs/ShowResults/n000067_0228_01.png differ
diff --git a/Imgs/ShowResults/n000184_0094_01.png b/Imgs/ShowResults/n000184_0094_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..56fc0bf25806c9a9ec0c47d557636f514ace5b97
Binary files /dev/null and b/Imgs/ShowResults/n000184_0094_01.png differ
diff --git a/Imgs/ShowResults/n000241_0132_04.png b/Imgs/ShowResults/n000241_0132_04.png
new file mode 100644
index 0000000000000000000000000000000000000000..e35b7c0e664ffd443f000fb05bcbb1fd442c0e2e
Binary files /dev/null and b/Imgs/ShowResults/n000241_0132_04.png differ
diff --git a/Imgs/ShowResults/n000262_0097_01.png b/Imgs/ShowResults/n000262_0097_01.png
new file mode 100644
index 0000000000000000000000000000000000000000..62590ddd60020fb4eb2dee611c87b810ebdfc6a5
Binary files /dev/null and b/Imgs/ShowResults/n000262_0097_01.png differ
diff --git a/Imgs/Whole/test1_0.jpg b/Imgs/Whole/test1_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d32a6d39ea1e3004fbfc37de79f38b008f997583
Binary files /dev/null and b/Imgs/Whole/test1_0.jpg differ
diff --git a/Imgs/Whole/test1_1.jpg b/Imgs/Whole/test1_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..72391507de1f2bd73f56928c9355e6d4b1b3b3f8
Binary files /dev/null and b/Imgs/Whole/test1_1.jpg differ
diff --git a/Imgs/Whole/test1_2.jpg b/Imgs/Whole/test1_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..647256df8b458cbf98906f211632ab050dca0e06
Binary files /dev/null and b/Imgs/Whole/test1_2.jpg differ
diff --git a/Imgs/Whole/test1_3.jpg b/Imgs/Whole/test1_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..12126d4b26a52903e4a51eb8a3acaee51228c8a1
Binary files /dev/null and b/Imgs/Whole/test1_3.jpg differ
diff --git a/Imgs/Whole/test2_0.jpg b/Imgs/Whole/test2_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8ec5fac179590e9e83656ad3a4a5cbfc9addc711
Binary files /dev/null and b/Imgs/Whole/test2_0.jpg differ
diff --git a/Imgs/Whole/test2_1.jpg b/Imgs/Whole/test2_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..24faf15ee5a326e4a21ea5dd235768e8920aac64
Binary files /dev/null and b/Imgs/Whole/test2_1.jpg differ
diff --git a/Imgs/Whole/test2_2.jpg b/Imgs/Whole/test2_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2f86fa9e4a2a0ee6b8cf748b3fd599e6fba82d80
Binary files /dev/null and b/Imgs/Whole/test2_2.jpg differ
diff --git a/Imgs/Whole/test2_3.jpg b/Imgs/Whole/test2_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fd0acbc0e03a41a2ac33e11071a4cc19fcf142d3
Binary files /dev/null and b/Imgs/Whole/test2_3.jpg differ
diff --git a/Imgs/Whole/test5_0.jpg b/Imgs/Whole/test5_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..03cd688650bd4b337d15e15c3a6ad79c48f25eb4
Binary files /dev/null and b/Imgs/Whole/test5_0.jpg differ
diff --git a/Imgs/Whole/test5_1.jpg b/Imgs/Whole/test5_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c1121f7b335e22a77d8e2473b00c8cf064ed6aaf
Binary files /dev/null and b/Imgs/Whole/test5_1.jpg differ
diff --git a/Imgs/Whole/test5_2.jpg b/Imgs/Whole/test5_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..12f57ba8f70004ddd86eea8ed68815a3431ca23e
Binary files /dev/null and b/Imgs/Whole/test5_2.jpg differ
diff --git a/Imgs/Whole/test5_3.jpg b/Imgs/Whole/test5_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7a39a6d41c87501a7a4ad84b55e1fdb1a3241493
Binary files /dev/null and b/Imgs/Whole/test5_3.jpg differ
diff --git a/Imgs/pipeline_a.png b/Imgs/pipeline_a.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a62bd821495874de3545d3c31748abfb9e6a2ae
Binary files /dev/null and b/Imgs/pipeline_a.png differ
diff --git a/Imgs/pipeline_b.png b/Imgs/pipeline_b.png
new file mode 100644
index 0000000000000000000000000000000000000000..f3170eb89e00500343738a51f608510e5a435ad4
Binary files /dev/null and b/Imgs/pipeline_b.png differ
diff --git a/README.md b/README.md
index 0b04060ac2511ead86987e4fbcbbfe9246b01fe8..08297cd71ad65292609b42a2970a269e65974f3e 100644
--- a/README.md
+++ b/README.md
@@ -1,45 +1,171 @@
----
-title: Deep Multi Scale
-emoji: 🦀
-colorFrom: green
-colorTo: green
-sdk: gradio
-app_file: app.py
-pinned: false
----
+## [Blind Face Restoration via Deep Multi-scale Component Dictionaries](https://arxiv.org/pdf/2008.00418.pdf)
-# Configuration
+>##### __Note: This branch contains all the restoration results, including the 512×512 face region and the final result obtained by putting the enhanced face back into the original input. The former version that can only generate the face result is kept in the [master branch](https://github.com/csxmli2016/DFDNet/tree/master).__
-`title`: _string_
-Display title for the Space
-`emoji`: _string_
-Space emoji (emoji-only character allowed)
+
+Overview of our proposed method. It mainly contains two parts: (a) the offline generation of multi-scale component dictionaries from a large number of high-quality images with diverse poses and expressions, where K-means is adopted to generate K clusters for each component (i.e., left/right eyes, nose and mouth) on different feature scales; and (b) the restoration process with dictionary feature transfer (DFT) blocks that provide the reference details in a progressive manner. Here, the DFT-i block takes the Scale-i component dictionaries as reference at the same feature level.
+
+
-`colorFrom`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+(a) Offline generation of multi-scale component dictionaries.
+
+(b) Architecture of our DFDNet for dictionary feature transfer.
-`colorTo`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-`sdk`: _string_
-Can be either `gradio`, `streamlit`, or `static`
+## Pre-trained Models and Dictionaries
+Download them from one of the following links and put them into ./.
+- [BaiduNetDisk](https://pan.baidu.com/s/1K4fzjPiezVSMl5NjHoJCGQ) (s9ht)
+- [GoogleDrive](https://drive.google.com/drive/folders/1bayYIUMCSGmoFPyd4Uu2Uwn347RW-vl5?usp=sharing)
-`sdk_version` : _string_
-Only applicable for `streamlit` SDK.
-See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+The folder structure should be:
+
+ .
+ ├── checkpoints
+ │ ├── facefh_dictionary
+ │ │ └── latest_net_G.pth
+ ├── weights
+ │ └── vgg19.pth
+ ├── DictionaryCenter512
+ │ ├── right_eye_256_center.npy
+ │ ├── right_eye_128_center.npy
+ │ ├── right_eye_64_center.npy
+ │ ├── right_eye_32_center.npy
+ │ └── ...
+ └── ...
-`app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
-Path is relative to the root of the repository.
+## Prerequisites
+>([Video Installation Tutorial](https://www.youtube.com/watch?v=OTqGYMSKGF4). Thanks to [bycloudump](https://www.youtube.com/channel/UCfg9ux4m8P0YDITTPptrmLg) for the tremendous help.)
+- PyTorch (≥1.1 is recommended)
+- dlib
+- dominate
+- cv2
+- tqdm
+- [face-alignment](https://github.com/1adrianb/face-alignment)
+ ```bash
+ cd ./FaceLandmarkDetection
+ python setup.py install
+ cd ..
+ ```
+
-`models`: _List[string]_
-HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
-Will be parsed automatically from your code if not specified here.
+## Testing
+```bash
+python test_FaceDict.py --test_path ./TestData/TestWhole --results_dir ./Results/TestWholeResults --upscale_factor 4 --gpu_ids 0
+```
+#### __Four parameters can be changed for flexible usage:__
+```
+--test_path # test image path
+--results_dir # save the results path
+--upscale_factor # the upsample factor for the final result
+--gpu_ids # gpu id. if use cpu, set gpu_ids=-1
+```
-`datasets`: _List[string]_
-HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
-Will be parsed automatically from your code if not specified here.
+>Note: our DFDNet can only generate a 512×512 face result for any given face image.
+
+#### __Result path contains the following folder:__
+- Step0_Input: ```# Save the input image.```
+- Step1_AffineParam: ```# Save the crop and align parameters for copying the face result to the original input.```
+- Step1_CropImg: ```# Save the cropped face images and resize them to 512×512.```
+- Step2_Landmarks: ```# Save the facial landmarks for RoIAlign.```
+- Step3_RestoreCropFace: ```# Save the face restoration result (512×512).```
+- Step4_FinalResults: ```# Save the final restoration result by putting the enhanced face back into the original input.```
+
+## Some plausible restoration results on real low-quality images
+
+
+| Input | Crop and Align | Restore Face | Final Results (UpScaleWhole=4) |
+
+*(Result image grid omitted; see Imgs/RealLR and Imgs/ShowResults for the example images.)*
+
+## TO DO LIST (if possible)
+- [ ] Enhance all the faces in one image.
+- [ ] Enhance the background.
+
+
+## Citation
+
+```
+@InProceedings{Li_2020_ECCV,
+author = {Li, Xiaoming and Chen, Chaofeng and Zhou, Shangchen and Lin, Xianhui and Zuo, Wangmeng and Zhang, Lei},
+title = {Blind Face Restoration via Deep Multi-scale Component Dictionaries},
+booktitle = {ECCV},
+year = {2020}
+}
+```
+
+
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
-`pinned`: _boolean_
-Whether the Space stays on top of your list.
diff --git a/TestData/TestWhole/test1.jpg b/TestData/TestWhole/test1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dfc2edf0dd62baa5e6d254d15b20d60fe999f86a
Binary files /dev/null and b/TestData/TestWhole/test1.jpg differ
diff --git a/TestData/TestWhole/test2.jpg b/TestData/TestWhole/test2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6cb36b9ee462c1e86f90e88cfaaaa5321512a42d
Binary files /dev/null and b/TestData/TestWhole/test2.jpg differ
diff --git a/TestData/TestWhole/test3.jpg b/TestData/TestWhole/test3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3f829f02de8f8ec16fe4488f5c00fd64c43fbf07
Binary files /dev/null and b/TestData/TestWhole/test3.jpg differ
diff --git a/TestData/TestWhole/test4.jpg b/TestData/TestWhole/test4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..84d1182cd011482530fffe35e8b875f0c8608eb3
Binary files /dev/null and b/TestData/TestWhole/test4.jpg differ
diff --git a/TestData/TestWhole/test5.jpg b/TestData/TestWhole/test5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1955cae069630b0b37124a5e8719faa2c4383063
Binary files /dev/null and b/TestData/TestWhole/test5.jpg differ
diff --git a/TestData/TestWhole/test6.png b/TestData/TestWhole/test6.png
new file mode 100644
index 0000000000000000000000000000000000000000..0e3c871d4599d8a06a12a60417da879c038af776
Binary files /dev/null and b/TestData/TestWhole/test6.png differ
diff --git a/data/MotionBlurKernel/m_01.mat b/data/MotionBlurKernel/m_01.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e25bac95b9723201e4522bb5c4ae63bafc0a2cec
Binary files /dev/null and b/data/MotionBlurKernel/m_01.mat differ
diff --git a/data/MotionBlurKernel/m_02.mat b/data/MotionBlurKernel/m_02.mat
new file mode 100644
index 0000000000000000000000000000000000000000..de05fde3f92121df1e1d8d7fc0ed3bd9575c2c90
Binary files /dev/null and b/data/MotionBlurKernel/m_02.mat differ
diff --git a/data/MotionBlurKernel/m_03.mat b/data/MotionBlurKernel/m_03.mat
new file mode 100644
index 0000000000000000000000000000000000000000..98a9aaf3701af01d2e16b3e6d9b0760a55eb5abf
Binary files /dev/null and b/data/MotionBlurKernel/m_03.mat differ
diff --git a/data/MotionBlurKernel/m_04.mat b/data/MotionBlurKernel/m_04.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9fc10561b7b3c6493a014c1e803798e2b395972f
Binary files /dev/null and b/data/MotionBlurKernel/m_04.mat differ
diff --git a/data/MotionBlurKernel/m_05.mat b/data/MotionBlurKernel/m_05.mat
new file mode 100644
index 0000000000000000000000000000000000000000..3989860435391f9d0dafc94567a02d60bcfc957b
Binary files /dev/null and b/data/MotionBlurKernel/m_05.mat differ
diff --git a/data/MotionBlurKernel/m_06.mat b/data/MotionBlurKernel/m_06.mat
new file mode 100644
index 0000000000000000000000000000000000000000..72d69530a288c3fd3779821c3162a1011d9dd795
Binary files /dev/null and b/data/MotionBlurKernel/m_06.mat differ
diff --git a/data/MotionBlurKernel/m_07.mat b/data/MotionBlurKernel/m_07.mat
new file mode 100644
index 0000000000000000000000000000000000000000..afcf374bdb69b8f29f0f63e23818031f7ba718e1
Binary files /dev/null and b/data/MotionBlurKernel/m_07.mat differ
diff --git a/data/MotionBlurKernel/m_08.mat b/data/MotionBlurKernel/m_08.mat
new file mode 100644
index 0000000000000000000000000000000000000000..97fdc99e8695c9a24c3c38053fa812806465e68c
Binary files /dev/null and b/data/MotionBlurKernel/m_08.mat differ
diff --git a/data/MotionBlurKernel/m_09.mat b/data/MotionBlurKernel/m_09.mat
new file mode 100644
index 0000000000000000000000000000000000000000..2b580721efe97c8c2048f6a763f0e2d9a73bd90c
Binary files /dev/null and b/data/MotionBlurKernel/m_09.mat differ
diff --git a/data/MotionBlurKernel/m_10.mat b/data/MotionBlurKernel/m_10.mat
new file mode 100644
index 0000000000000000000000000000000000000000..fbf1e9b9e900e5575c80ebbf9a5e44c9866d7899
Binary files /dev/null and b/data/MotionBlurKernel/m_10.mat differ
diff --git a/data/MotionBlurKernel/m_11.mat b/data/MotionBlurKernel/m_11.mat
new file mode 100644
index 0000000000000000000000000000000000000000..13073988562afadec1b06d7e2aa5a149931e4270
Binary files /dev/null and b/data/MotionBlurKernel/m_11.mat differ
diff --git a/data/MotionBlurKernel/m_12.mat b/data/MotionBlurKernel/m_12.mat
new file mode 100644
index 0000000000000000000000000000000000000000..8238a8db58ccdcfc5a767ce864ae9fca2148d702
Binary files /dev/null and b/data/MotionBlurKernel/m_12.mat differ
diff --git a/data/MotionBlurKernel/m_13.mat b/data/MotionBlurKernel/m_13.mat
new file mode 100644
index 0000000000000000000000000000000000000000..bd9f6bdebf79d8a1c201833b938cc54f143f0111
Binary files /dev/null and b/data/MotionBlurKernel/m_13.mat differ
diff --git a/data/MotionBlurKernel/m_14.mat b/data/MotionBlurKernel/m_14.mat
new file mode 100644
index 0000000000000000000000000000000000000000..cde1d24dc67470cde9d79c2741f290395dfe27f3
Binary files /dev/null and b/data/MotionBlurKernel/m_14.mat differ
diff --git a/data/MotionBlurKernel/m_15.mat b/data/MotionBlurKernel/m_15.mat
new file mode 100644
index 0000000000000000000000000000000000000000..3538ce4a1d6c78344ddff7325d4e9ce878fa464e
Binary files /dev/null and b/data/MotionBlurKernel/m_15.mat differ
diff --git a/data/MotionBlurKernel/m_16.mat b/data/MotionBlurKernel/m_16.mat
new file mode 100644
index 0000000000000000000000000000000000000000..3e3f9c6d5f6ec1d3a687a706779c7c42405b84be
Binary files /dev/null and b/data/MotionBlurKernel/m_16.mat differ
diff --git a/data/MotionBlurKernel/m_17.mat b/data/MotionBlurKernel/m_17.mat
new file mode 100644
index 0000000000000000000000000000000000000000..eb9b331268f6671b0f9d5a3920ccc7eca89413b3
Binary files /dev/null and b/data/MotionBlurKernel/m_17.mat differ
diff --git a/data/MotionBlurKernel/m_18.mat b/data/MotionBlurKernel/m_18.mat
new file mode 100644
index 0000000000000000000000000000000000000000..7412fce0daf0ed74af4877c13e5960f5e3174468
Binary files /dev/null and b/data/MotionBlurKernel/m_18.mat differ
diff --git a/data/MotionBlurKernel/m_19.mat b/data/MotionBlurKernel/m_19.mat
new file mode 100644
index 0000000000000000000000000000000000000000..2c06d2aac4a33fcebaebc2d8b1c65a34e07ffb14
Binary files /dev/null and b/data/MotionBlurKernel/m_19.mat differ
diff --git a/data/MotionBlurKernel/m_20.mat b/data/MotionBlurKernel/m_20.mat
new file mode 100644
index 0000000000000000000000000000000000000000..70385f34b2b0605cca198ba14a3276adac2976af
Binary files /dev/null and b/data/MotionBlurKernel/m_20.mat differ
diff --git a/data/MotionBlurKernel/m_21.mat b/data/MotionBlurKernel/m_21.mat
new file mode 100644
index 0000000000000000000000000000000000000000..773d50937455be6714f324ab7c7967bc02287aaf
Binary files /dev/null and b/data/MotionBlurKernel/m_21.mat differ
diff --git a/data/MotionBlurKernel/m_22.mat b/data/MotionBlurKernel/m_22.mat
new file mode 100644
index 0000000000000000000000000000000000000000..c051c507f9e46efe768e4308a9b555b2800c28f9
Binary files /dev/null and b/data/MotionBlurKernel/m_22.mat differ
diff --git a/data/MotionBlurKernel/m_23.mat b/data/MotionBlurKernel/m_23.mat
new file mode 100644
index 0000000000000000000000000000000000000000..48303ad2d9f4ef38ded3a52fd0ce1fd2f39651e3
Binary files /dev/null and b/data/MotionBlurKernel/m_23.mat differ
diff --git a/data/MotionBlurKernel/m_24.mat b/data/MotionBlurKernel/m_24.mat
new file mode 100644
index 0000000000000000000000000000000000000000..129f770db08b5fe1dea0c2335a7d466183935383
Binary files /dev/null and b/data/MotionBlurKernel/m_24.mat differ
diff --git a/data/MotionBlurKernel/m_25.mat b/data/MotionBlurKernel/m_25.mat
new file mode 100644
index 0000000000000000000000000000000000000000..42f2aa675a5e7660a5a3b48846a710839c03402b
Binary files /dev/null and b/data/MotionBlurKernel/m_25.mat differ
diff --git a/data/MotionBlurKernel/m_26.mat b/data/MotionBlurKernel/m_26.mat
new file mode 100644
index 0000000000000000000000000000000000000000..480f0694f27dc46970022a289560fa36a421e73b
Binary files /dev/null and b/data/MotionBlurKernel/m_26.mat differ
diff --git a/data/MotionBlurKernel/m_27.mat b/data/MotionBlurKernel/m_27.mat
new file mode 100644
index 0000000000000000000000000000000000000000..7e9cbb553b6775bdd8f66c7868ce1f6b345c7ec8
Binary files /dev/null and b/data/MotionBlurKernel/m_27.mat differ
diff --git a/data/MotionBlurKernel/m_28.mat b/data/MotionBlurKernel/m_28.mat
new file mode 100644
index 0000000000000000000000000000000000000000..03d3aa5eb2c02bf50435c036beb0cc0e95a7febe
Binary files /dev/null and b/data/MotionBlurKernel/m_28.mat differ
diff --git a/data/MotionBlurKernel/m_29.mat b/data/MotionBlurKernel/m_29.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e238ed66b930a38b0116ebba8ce913403d7ea086
Binary files /dev/null and b/data/MotionBlurKernel/m_29.mat differ
diff --git a/data/MotionBlurKernel/m_30.mat b/data/MotionBlurKernel/m_30.mat
new file mode 100644
index 0000000000000000000000000000000000000000..700e4c230c199f1c153e634fa6574dd158307124
Binary files /dev/null and b/data/MotionBlurKernel/m_30.mat differ
diff --git a/data/MotionBlurKernel/m_31.mat b/data/MotionBlurKernel/m_31.mat
new file mode 100644
index 0000000000000000000000000000000000000000..f335141069b0797699bfe758dddc96a1da29e295
Binary files /dev/null and b/data/MotionBlurKernel/m_31.mat differ
diff --git a/data/MotionBlurKernel/m_32.mat b/data/MotionBlurKernel/m_32.mat
new file mode 100644
index 0000000000000000000000000000000000000000..5acf9b436bcc1cf1f7d1980cff0893ee392e6548
Binary files /dev/null and b/data/MotionBlurKernel/m_32.mat differ
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2193cc788f4a263e13cf4890ff7c2231ce951776
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+import importlib
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+
+ # Given the option --dataset_mode [datasetname],
+ # the file "data/datasetname_dataset.py"
+ # will be imported.
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ # In the file, the class called DatasetNameDataset() will
+ # be instantiated. It has to be a subclass of BaseDataset,
+ # and it is case-insensitive.
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() \
+ and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+ exit(0)
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt):
+ dataset = find_dataset_using_name(opt.dataset_mode)
+ instance = dataset()
+ instance.initialize(opt)
+ print("dataset [%s] was created" % (instance.name()))
+ return instance
+
+
+def CreateDataLoader(opt):
+ data_loader = CustomDatasetDataLoader()
+ data_loader.initialize(opt)
+ return data_loader
+
+
+# Wrapper class of Dataset class that performs
+# multi-threaded data loading
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = create_dataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+        # DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
+        # dataset: the Dataset object to load; shuffle: whether to shuffle samples; sampler: sampling strategy;
+        # num_workers: number of worker processes (0 means no multiprocessing); collate_fn: how samples are merged into a batch;
+        # pin_memory: whether to copy tensors into pinned memory; drop_last: if True, drop the last incomplete batch
+        # when the dataset size is not a multiple of batch_size.
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batchSize >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eabcca633ec86e152f6fec4996aea9f541a6d7e
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image, ImageFilter
+import numpy as np
+import cv2
+import math
+from util import util
+from scipy.io import loadmat
+from PIL import Image
+import PIL
+
+
+class AlignedDataset(BaseDataset):
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ return parser
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.partpath = opt.partroot
+ self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+ self.AB_paths = sorted(make_dataset(self.dir_AB))
+ self.is_real = opt.is_real
+ # assert(opt.resize_or_crop == 'resize_and_crop')
+ assert(opt.resize_or_crop == 'degradation')
+
+ def AddNoise(self,img): # noise
+ if random.random() > 0.9: #
+ return img
+ self.sigma = np.random.randint(1, 11)
+ img_tensor = torch.from_numpy(np.array(img)).float()
+ noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0)
+
+ noiseimg = torch.clamp(noise+img_tensor,0,255)
+ return Image.fromarray(np.uint8(noiseimg.numpy()))
+
+ def AddBlur(self,img): # gaussian blur or motion blur
+ if random.random() > 0.9: #
+ return img
+ img = np.array(img)
+ if random.random() > 0.35: ##gaussian blur
+            blursize = random.randint(1, 17) * 2 + 1  # odd kernel sizes in [3, 35]
+ blursigma = random.randint(3, 20)
+ img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10)
+ else: #motion blur
+ M = random.randint(1,32)
+ KName = './data/MotionBlurKernel/m_%02d.mat' % M
+ k = loadmat(KName)['kernel']
+ k = k.astype(np.float32)
+ k /= np.sum(k)
+ img = cv2.filter2D(img,-1,k)
+ return Image.fromarray(img)
+
+ def AddDownSample(self,img): # downsampling
+ if random.random() > 0.95: #
+ return img
+ sampler = random.randint(20, 100)*1.0
+ img = img.resize((int(self.opt.fineSize/sampler*10.0), int(self.opt.fineSize/sampler*10.0)), Image.BICUBIC)
+ return img
+
+ def AddJPEG(self,img): # JPEG compression
+ if random.random() > 0.6: #
+ return img
+ imQ = random.randint(40, 80)
+ img = np.array(img)
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95
+ _, encA = cv2.imencode('.jpg',img,encode_param)
+ img = cv2.imdecode(encA,1)
+ return Image.fromarray(img)
+
+ def AddUpSample(self,img):
+ return img.resize((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC)
+
+ def __getitem__(self, index): #
+
+ AB_path = self.AB_paths[index]
+ Imgs = Image.open(AB_path).convert('RGB')
+ # #
+ A = Imgs.resize((self.opt.fineSize, self.opt.fineSize))
+ A = transforms.ColorJitter(0.3, 0.3, 0.3, 0)(A)
+ C = A
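+        # Synthesize the degraded input A with blur -> downsample -> noise -> JPEG -> upsample,
+        # while C keeps the clean ground-truth image.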
+ A = self.AddUpSample(self.AddJPEG(self.AddNoise(self.AddDownSample(self.AddBlur(A)))))
+
+ tmps = AB_path.split('/')
+ ImgName = tmps[-1]
+ Part_locations = self.get_part_location(self.partpath, ImgName, 2)
+
+ A = transforms.ToTensor()(A) #
+ C = transforms.ToTensor()(C)
+
+ ##
+ A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) #
+ C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C) #
+ return {'A':A, 'C':C, 'A_paths': AB_path,'Part_locations': Part_locations}
+
+ def get_part_location(self, landmarkpath, imgname, downscale=1):
+ Landmarks = []
+ with open(os.path.join(landmarkpath,imgname+'.txt'),'r') as f:
+ for line in f:
+                tmp = [float(i) for i in line.split(' ') if i != '\n']
+ Landmarks.append(tmp)
+ Landmarks = np.array(Landmarks)/downscale # 512 * 512
+
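+        # Landmark index groups in the 68-point convention: left/right eyebrow + eye,
+        # nose, and mouth.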
+ Map_LE = list(np.hstack((range(17,22), range(36,42))))
+ Map_RE = list(np.hstack((range(22,27), range(42,48))))
+ Map_NO = list(range(29,36))
+ Map_MO = list(range(48,68))
+ #left eye
+ Mean_LE = np.mean(Landmarks[Map_LE],0)
+ L_LE = np.max((np.max(np.max(Landmarks[Map_LE],0) - np.min(Landmarks[Map_LE],0))/2,16))
+ Location_LE = np.hstack((Mean_LE - L_LE + 1, Mean_LE + L_LE)).astype(int)
+ #right eye
+ Mean_RE = np.mean(Landmarks[Map_RE],0)
+ L_RE = np.max((np.max(np.max(Landmarks[Map_RE],0) - np.min(Landmarks[Map_RE],0))/2,16))
+ Location_RE = np.hstack((Mean_RE - L_RE + 1, Mean_RE + L_RE)).astype(int)
+ #nose
+ Mean_NO = np.mean(Landmarks[Map_NO],0)
+ L_NO = np.max((np.max(np.max(Landmarks[Map_NO],0) - np.min(Landmarks[Map_NO],0))/2,16))
+ Location_NO = np.hstack((Mean_NO - L_NO + 1, Mean_NO + L_NO)).astype(int)
+ #mouth
+ Mean_MO = np.mean(Landmarks[Map_MO],0)
+ L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16))
+
+ Location_MO = np.hstack((Mean_MO - L_MO + 1, Mean_MO + L_MO)).astype(int)
+ return Location_LE, Location_RE, Location_NO, Location_MO
+
+ def __len__(self): #
+ return len(self.AB_paths)
+
+ def name(self):
+ return 'AlignedDataset'
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5abf22edfd9baf7e25d281d1f7bfe714b6b9294e
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,10 @@
+class BaseDataLoader():
+ def __init__(self):
+ pass
+
+ def initialize(self, opt):
+ self.opt = opt
+ pass
+
+ def load_data(self):
+ return None
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1cee3e7296ac87a6470ab91af70da4ca6a2fe7
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,104 @@
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+
+
+class BaseDataset(data.Dataset):
+ def __init__(self):
+ super(BaseDataset, self).__init__()
+
+ def name(self):
+ return 'BaseDataset'
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ return parser
+
+ def initialize(self, opt):
+ pass
+
+ def __len__(self):
+ return 0
+
+
+def get_transform(opt):
+ transform_list = []
+ if opt.resize_or_crop == 'resize_and_crop':
+
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Resize(osize, Image.BICUBIC))
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+ elif opt.resize_or_crop == 'crop':
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+ elif opt.resize_or_crop == 'scale_width':
+ transform_list.append(transforms.Lambda(
+ lambda img: __scale_width(img, opt.fineSize)))
+ elif opt.resize_or_crop == 'scale_width_and_crop':
+ transform_list.append(transforms.Lambda(
+ lambda img: __scale_width(img, opt.loadSize)))
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+ elif opt.resize_or_crop == 'none':
+ transform_list.append(transforms.Lambda(
+ lambda img: __adjust(img)))
+ else:
+ raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.RandomHorizontalFlip())
+
+ transform_list += [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+# just modify the width and height to be multiple of 4
+def __adjust(img):
+ ow, oh = img.size
+
+ # the size needs to be a multiple of this number,
+ # because going through generator network may change img size
+ # and eventually cause size mismatch error
+ mult = 4
+ if ow % mult == 0 and oh % mult == 0:
+ return img
+ w = (ow - 1) // mult
+ w = (w + 1) * mult
+ h = (oh - 1) // mult
+ h = (h + 1) * mult
+
+ if ow != w or oh != h:
+ __print_size_warning(ow, oh, w, h)
+
+ return img.resize((w, h), Image.BICUBIC)
+
+
+def __scale_width(img, target_width):
+ ow, oh = img.size
+
+ # the size needs to be a multiple of this number,
+ # because going through generator network may change img size
+ # and eventually cause size mismatch error
+ mult = 4
+ assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
+ if (ow == target_width and oh % mult == 0):
+ return img
+ w = target_width
+ target_height = int(target_width * oh / ow)
+ m = (target_height - 1) // mult
+ h = (m + 1) * mult
+
+ if target_height != h:
+ __print_size_warning(target_width, target_height, w, h)
+
+ return img.resize((w, h), Image.BICUBIC)
+
+
+def __print_size_warning(ow, oh, w, h):
+ if not hasattr(__print_size_warning, 'has_printed'):
+ print("The image size needs to be a multiple of 4. "
+ "The loaded image size was (%d, %d), so it was adjusted to "
+ "(%d, %d). This adjustment will be done to all images "
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
+ __print_size_warning.has_printed = True
+
+
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..23df4c8bc6b692b84d26ae85d81e8a5ac3c5aad4
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,69 @@
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dirs):
+ images = []
+ assert os.path.isdir(dirs), '%s is not a valid directory' % dirs
+
+ for root, _, fnames in sorted(os.walk(dirs)):
+ fnames.sort()
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+
+ return images
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " +
+ ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/data/single_dataset.py b/data/single_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f515b024b3c91aedfea824f075cf3a81c4376e
--- /dev/null
+++ b/data/single_dataset.py
@@ -0,0 +1,42 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ return parser
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot)
+
+ self.A_paths = make_dataset(self.dir_A)
+
+ self.A_paths = sorted(self.A_paths)
+
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index]
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ return len(self.A_paths)
+
+ def name(self):
+ return 'SingleImageDataset'
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..06938b7f777230802305570e655af31be6121e03
--- /dev/null
+++ b/data/unaligned_dataset.py
@@ -0,0 +1,62 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+
+
+class UnalignedDataset(BaseDataset):
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ return parser
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+ self.A_paths = make_dataset(self.dir_A)
+ self.B_paths = make_dataset(self.dir_B)
+
+ self.A_paths = sorted(self.A_paths)
+ self.B_paths = sorted(self.B_paths)
+ self.A_size = len(self.A_paths)
+ self.B_size = len(self.B_paths)
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index % self.A_size]
+ if self.opt.serial_batches:
+ index_B = index % self.B_size
+ else:
+ index_B = random.randint(0, self.B_size - 1)
+ B_path = self.B_paths[index_B]
+ # print('(A, B) = (%d, %d)' % (index_A, index_B))
+ A_img = Image.open(A_path).convert('RGB')
+ B_img = Image.open(B_path).convert('RGB')
+
+ A = self.transform(A_img)
+ B = self.transform(B_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ output_nc = self.opt.input_nc
+ else:
+ input_nc = self.opt.input_nc
+ output_nc = self.opt.output_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ if output_nc == 1: # RGB to gray
+ tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+ B = tmp.unsqueeze(0)
+ return {'A': A, 'B': B,
+ 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ return max(self.A_size, self.B_size)
+
+ def name(self):
+ return 'UnalignedDataset'
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..259f43fb224809862d90ed0c8b9d234bfa809540
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,37 @@
+import importlib
+from models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ # Given the option --model [modelname],
+ # the file "models/modelname_model.py"
+ # will be imported.
+ model_filename = "models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+
+ # In the file, the class called ModelNameModel() will
+ # be instantiated. It has to be a subclass of BaseModel,
+ # and it is case-insensitive.
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+
+ return model
+
+
+def get_option_setter(model_name):
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ model = find_model_using_name(opt.model)
+ instance = model()
+ instance.initialize(opt)
+ return instance
diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb0838fd62bb126983c6b1f5648270af07733b8c
Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/__pycache__/base_model.cpython-38.pyc b/models/__pycache__/base_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fe02be4b617dc504e316e3aa82b615529c8ba03
Binary files /dev/null and b/models/__pycache__/base_model.cpython-38.pyc differ
diff --git a/models/__pycache__/networks.cpython-38.pyc b/models/__pycache__/networks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df2afa669b6ddfd3d9fa66569c3ad064e3c576d2
Binary files /dev/null and b/models/__pycache__/networks.cpython-38.pyc differ
diff --git a/models/__pycache__/test_model.cpython-38.pyc b/models/__pycache__/test_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cb31788851358cd78aa016d2af8e4c7be893fd1
Binary files /dev/null and b/models/__pycache__/test_model.cpython-38.pyc differ
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1e2f7d81fd6d4f74547d54a0a63335ceab2cf6
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,173 @@
+import os
+import torch
+from collections import OrderedDict
+from . import networks
+
+
+class BaseModel():
+
+ # modify parser to add command line options,
+ # and also change the default values if needed
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ return parser
+
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ if opt.resize_or_crop != 'scale_width':
+ torch.backends.cudnn.benchmark = True
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.image_paths = []
+ # self.optimizers = []
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # load and print networks; create schedulers
+ def setup(self, opt, parser=None):
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ if not self.isTrain or opt.continue_train:
+ self.load_networks(opt.which_epoch)
+ self.print_networks(opt.verbose)
+
+ # make models eval mode during test time
+ def eval(self):
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ net.eval()
+
+ # used in test time, wrapping `forward` in no_grad() so we don't save
+ # intermediate steps for backprop
+ def test(self):
+ with torch.no_grad():
+ self.forward()
+
+ # get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def optimize_parameters(self):
+ pass
+
+ # update learning rate (called once every epoch)
+ def update_learning_rate(self):
+ for scheduler in self.schedulers:
+ scheduler.step()
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ # return visualization images. train.py will display these images and save them to an HTML file
+ def get_current_visuals(self):
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)
+ return visual_ret
+
+ # return training losses/errors. train.py will print out these errors as debugging information
+ def get_current_losses(self):
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ # float(...) works for both scalar tensor and float number
+ errors_ret[name] = float(getattr(self, 'loss_' + name))
+ return errors_ret
+
+ # save models to the disk
+ def save_networks(self, which_epoch):
+ for name in self.model_names:
+ if isinstance(name, str):
+ save_filename = '%s_net_%s.pth' % (which_epoch, name)
+ save_path = os.path.join(self.save_dir, save_filename)
+ net = getattr(self, 'net' + name)
+
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+ torch.save(net.module.cpu().state_dict(), save_path)
+ net.cuda(self.gpu_ids[0])
+ else:
+ torch.save(net.cpu().state_dict(), save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ # load models from the disk
+ def load_networks(self, which_epoch):
+ for name in self.model_names:
+ if isinstance(name, str):
+ load_filename = '%s_net_%s.pth' % (which_epoch, name)
+ load_path = os.path.join(self.save_dir, load_filename)
+ net = getattr(self, 'net' + name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ # print('loading the model from %s' % load_path)
+ # if you are using PyTorch newer than 0.4 (e.g., built from
+ # GitHub source), you can remove str() on self.device
+ if not os.path.exists(load_path):
+ continue
+ state_dict = torch.load(load_path, map_location=str(self.device))
+ if hasattr(state_dict, '_metadata'):
+ del state_dict._metadata
+
+ # patch InstanceNorm checkpoints prior to 0.4
+ # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
+ # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+ model_dict = net.state_dict()
+ # new_dict = {k: v for k, v in state_dict.items() if k in model_dict.keys()}
+ new_dict = {}
+ for k, v in state_dict.items():
+ if k in model_dict.keys():
+ # print(k)
+ # if k == 'sff_branch.0.sff0.MaskModel.0.weight' or k =='sff_branch.0.sff1.MaskModel.0.weight' or k == 'sff_branch.1.sff0.MaskModel.0.weight' or k =='sff_branch.1.sff1.MaskModel.0.weight' or k == 'sff_branch.2.sff0.MaskModel.0.weight' or k =='sff_branch.2.sff1.MaskModel.0.weight' or k == 'sff_branch.3.sff0.MaskModel.0.weight' or k =='sff_branch.3.sff1.MaskModel.0.weight' or k == 'sff_branch.4.MaskModel.0.weight' :
+ # continue
+ # if 'Mask_CModel.model' in k:
+ # continue
+ new_dict[k] = v
+ model_dict.update(new_dict)
+ net.load_state_dict(model_dict)
+
+ # print network information
+ def print_networks(self, verbose):
+ # print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ # if verbose:
+ # print(net)
+ # print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ # print('-----------------------------------------------')
+
+ # set requires_grad=False to avoid unnecessary computation
+ def set_requires_grad(self, nets, requires_grad=False):
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..761bfdd2f20b1ede959bad67783b3d1ca7415b9e
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,652 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch.nn.functional as F
+from torch.nn import Parameter as P
+from util import util
+from torchvision import models
+import scipy.io as sio
+import numpy as np
+import scipy.ndimage
+import torch.nn.utils.spectral_norm as SpectralNorm
+
+from torch.autograd import Function
+from math import sqrt
+import random
+import os
+import math
+
+from sync_batchnorm import convert_model
+####
+
+###############################################################################
+# Helper Functions
+###############################################################################
+def get_norm_layer(norm_type='instance'):
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
+ elif norm_type == 'none':
+ norm_layer = None
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+ return norm_layer
+
+
+def get_scheduler(optimizer, opt):
+ if opt.lr_policy == 'lambda':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+
+ return scheduler
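+
+ # A worked example for the default 'lambda' policy (values are illustrative):
+ # the multiplier applied to the base learning rate is
+ #   lr_l(epoch) = 1.0 - max(0, epoch + 1 + epoch_count - niter) / (niter_decay + 1)
+ # so the rate stays at its initial value for roughly the first `niter` epochs and
+ # then decays linearly towards zero over the following `niter_decay` epochs.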
+
+
+def init_weights(net, init_type='normal', gain=0.02):
+ def init_func(m):
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1:
+ init.normal_(m.weight.data, 1.0, gain)
+ init.constant_(m.bias.data, 0.0)
+
+ print('initialize network with %s' % init_type)
+ net.apply(init_func)
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True):
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ net = convert_model(net)
+ net.to(gpu_ids[0])
+ net = torch.nn.DataParallel(net, gpu_ids)
+
+ if init_flag:
+
+ init_weights(net, init_type, gain=init_gain)
+
+ return net
+
+
+# compute adaptive instance norm
+def calc_mean_std(feat, eps=1e-5):
+ # eps is a small value added to the variance to avoid divide-by-zero.
+ size = feat.size()
+ assert (len(size) == 3)
+ C, _ = size[:2]
+ feat_var = feat.contiguous().view(C, -1).var(dim=1) + eps
+ feat_std = feat_var.sqrt().view(C, 1, 1)
+ feat_mean = feat.contiguous().view(C, -1).mean(dim=1).view(C, 1, 1)
+
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat): # content_feat is degraded feature, style is ref feature
+ assert (content_feat.size()[:1] == style_feat.size()[:1])
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+
+ normalized_feat = (content_feat - content_mean.expand(
+ size)) / content_std.expand(size)
+
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def calc_mean_std_4D(feat, eps=1e-5):
+ # eps is a small value added to the variance to avoid divide-by-zero.
+ size = feat.size()
+ assert (len(size) == 4)
+ N, C = size[:2]
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degraded feature
+ # assert (content_feat.size()[:2] == style_feat.size()[:2])
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std_4D(style_feat)
+
+ content_mean, content_std = calc_mean_std_4D(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(
+ size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
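+
+ # Both helpers follow the AdaIN transfer rule
+ #   AdaIN(c, s) = sigma(s) * (c - mu(c)) / sigma(c) + mu(s)
+ # with per-channel statistics mu/sigma. A minimal sketch with assumed shapes
+ # (batch sizes may differ because each tensor uses its own statistics):
+ #   dict_feat = torch.randn(8, 512, 20, 20)  # K dictionary components ("content" here)
+ #   deg_feat  = torch.randn(1, 512, 20, 20)  # degraded query feature ("style" here)
+ #   out = adaptive_instance_normalization_4D(dict_feat, deg_feat)
+ #   # -> dictionary components re-normalized to the degraded feature's statistics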
+
+def define_G(which_model_netG, gpu_ids=[]):
+ if which_model_netG == 'UNetDictFace':
+ netG = UNetDictFace(64)
+ init_flag = False
+ else:
+ raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
+ return init_net(netG, 'normal', 0.02, gpu_ids, init_flag)
+
+
+##############################################################################
+# Classes
+############################################################################################################################################
+
+
+def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
+ return nn.Sequential(
+ SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+# conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias),
+# nn.BatchNorm2d(out_channels),
+ nn.LeakyReLU(0.2),
+ SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+ )
+class MSDilateBlock(nn.Module):
+ def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
+ super(MSDilateBlock, self).__init__()
+ self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
+ self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
+ self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
+ self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
+ self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
+ def forward(self, x):
+ conv1 = self.conv1(x)
+ conv2 = self.conv2(x)
+ conv3 = self.conv3(x)
+ conv4 = self.conv4(x)
+ cat = torch.cat([conv1, conv2, conv3, conv4], 1)
+ out = self.convi(cat) + x
+ return out
+
+##############################UNetFace#########################
+class AdaptiveInstanceNorm(nn.Module):
+ def __init__(self, in_channel):
+ super().__init__()
+ self.norm = nn.InstanceNorm2d(in_channel)
+
+ def forward(self, input, style):
+ style_mean, style_std = calc_mean_std_4D(style)
+ out = self.norm(input)
+ size = input.size()
+ out = style_std.expand(size) * out + style_mean.expand(size)
+ return out
+
+class BlurFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+
+ grad_input = F.conv2d(
+ grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
+ )
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+
+ grad_input = F.conv2d(
+ gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
+ )
+ return grad_input, None, None
+
+
+class BlurFunction(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+
+ output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+
+ return grad_input, None, None
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+ def __init__(self, channel):
+ super().__init__()
+
+ weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+ weight = weight.view(1, 1, 3, 3)
+ weight = weight / weight.sum()
+ weight_flip = torch.flip(weight, [2, 3])
+
+ self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
+ self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
+
+ def forward(self, input):
+ return blur(input, self.weight, self.weight_flip)
+
+class EqualLR:
+ def __init__(self, name):
+ self.name = name
+
+ def compute_weight(self, module):
+ weight = getattr(module, self.name + '_orig')
+ fan_in = weight.data.size(1) * weight.data[0][0].numel()
+ return weight * sqrt(2 / fan_in)
+ @staticmethod
+ def apply(module, name):
+ fn = EqualLR(name)
+
+ weight = getattr(module, name)
+ del module._parameters[name]
+ module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+ module.register_forward_pre_hook(fn)
+
+ return fn
+
+ def __call__(self, module, input):
+ weight = self.compute_weight(module)
+ setattr(module, self.name, weight)
+
+def equal_lr(module, name='weight'):
+ EqualLR.apply(module, name)
+ return module
+
+class EqualConv2d(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ conv = nn.Conv2d(*args, **kwargs)
+ conv.weight.data.normal_()
+ conv.bias.data.zero_()
+ self.conv = equal_lr(conv)
+ def forward(self, input):
+ return self.conv(input)
+
+class NoiseInjection(nn.Module):
+ def __init__(self, channel):
+ super().__init__()
+ self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+ def forward(self, image, noise):
+ return image + self.weight * noise
+
+class StyledUpBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
+ super().__init__()
+ if upsample:
+ self.conv1 = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ Blur(out_channel),
+ # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding),
+ SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ Blur(in_channel),
+ # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding)
+ SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2),
+ )
+ self.convup = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ # EqualConv2d(out_channel, out_channel, kernel_size, padding=padding),
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2),
+ # Blur(out_channel),
+ )
+ # self.noise1 = equal_lr(NoiseInjection(out_channel))
+ # self.adain1 = AdaptiveInstanceNorm(out_channel)
+ self.lrelu1 = nn.LeakyReLU(0.2)
+
+ # self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
+ # self.noise2 = equal_lr(NoiseInjection(out_channel))
+ # self.adain2 = AdaptiveInstanceNorm(out_channel)
+ # self.lrelu2 = nn.LeakyReLU(0.2)
+
+ self.ScaleModel1 = nn.Sequential(
+ # Blur(in_channel),
+ SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+ # nn.Conv2d(in_channel,out_channel,3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
+ # nn.Conv2d(out_channel, out_channel, 3, 1, 1)
+ )
+ self.ShiftModel1 = nn.Sequential(
+ # Blur(in_channel),
+ SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+ # nn.Conv2d(in_channel,out_channel,3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+ nn.Sigmoid(),
+ # nn.Conv2d(out_channel, out_channel, 3, 1, 1)
+ )
+
+ def forward(self, input, style):
+ out = self.conv1(input)
+# out = self.noise1(out, noise)
+ out = self.lrelu1(out)
+
+ Shift1 = self.ShiftModel1(style)
+ Scale1 = self.ScaleModel1(style)
+ out = out * Scale1 + Shift1
+ # out = self.adain1(out, style)
+ outup = self.convup(out)
+
+ return outup
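+
+ # Note: the forward pass above applies an SFT-style modulation -- ScaleModel1 and
+ # ShiftModel1 predict a per-pixel scale and shift from the `style` input (in
+ # UNetDictFace below, the dictionary-updated VGG feature), which are applied to the
+ # convolved input (out * Scale1 + Shift1) before the final upsampling convolution.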
+
+##############################################################################
+##Face Dictionary
+##############################################################################
+class VGGFeat(torch.nn.Module):
+ """
+ Input: (B, C, H, W), RGB, [-1, 1]
+ """
+ def __init__(self, weight_path='./weights/vgg19.pth'):
+ super().__init__()
+ self.model = models.vgg19(pretrained=False)
+ self.build_vgg_layers()
+
+ self.model.load_state_dict(torch.load(weight_path))
+
+ self.register_parameter("RGB_mean", nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)))
+ self.register_parameter("RGB_std", nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)))
+
+ # self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def build_vgg_layers(self):
+ vgg_pretrained_features = self.model.features
+ self.features = []
+ # feature_layers = [0, 3, 8, 17, 26, 35]
+ feature_layers = [0, 8, 17, 26, 35]
+ for i in range(len(feature_layers)-1):
+ module_layers = torch.nn.Sequential()
+ for j in range(feature_layers[i], feature_layers[i+1]):
+ module_layers.add_module(str(j), vgg_pretrained_features[j])
+ self.features.append(module_layers)
+ self.features = torch.nn.ModuleList(self.features)
+
+ def preprocess(self, x):
+ x = (x + 1) / 2
+ x = (x - self.RGB_mean) / self.RGB_std
+ if x.shape[3] < 224:
+ x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
+ return x
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ features = []
+ for m in self.features:
+ # print(m)
+ x = m(x)
+ features.append(x)
+ return features
+
+def compute_sum(x, axis=None, keepdim=False):
+ if not axis:
+ axis = range(len(x.shape))
+ for i in sorted(axis, reverse=True):
+ x = torch.sum(x, dim=i, keepdim=keepdim)
+ return x
+def ToRGB(in_channel):
+ return nn.Sequential(
+ SpectralNorm(nn.Conv2d(in_channel,in_channel,3, 1, 1)),
+ nn.LeakyReLU(0.2),
+ SpectralNorm(nn.Conv2d(in_channel,3,3, 1, 1))
+ )
+
+def AttentionBlock(in_channel):
+ return nn.Sequential(
+ SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+ nn.LeakyReLU(0.2),
+ SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))
+ )
+
+class UNetDictFace(nn.Module):
+ def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'):
+ super().__init__()
+
+ self.part_sizes = np.array([80,80,50,110]) # size for 512
+ self.feature_sizes = np.array([256,128,64,32])
+ self.channel_sizes = np.array([128,256,512,512])
+ Parts = ['left_eye','right_eye','nose','mouth']
+ self.Dict_256 = {}
+ self.Dict_128 = {}
+ self.Dict_64 = {}
+ self.Dict_32 = {}
+ for j,i in enumerate(Parts):
+ f_256 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_256_center.npy'.format(i)), allow_pickle=True))
+
+ f_256_reshape = f_256.reshape(f_256.size(0),self.channel_sizes[0],self.part_sizes[j]//2,self.part_sizes[j]//2)
+ max_256 = torch.max(torch.sqrt(compute_sum(torch.pow(f_256_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
+ self.Dict_256[i] = f_256_reshape #/ max_256
+
+ f_128 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_128_center.npy'.format(i)), allow_pickle=True))
+
+ f_128_reshape = f_128.reshape(f_128.size(0),self.channel_sizes[1],self.part_sizes[j]//4,self.part_sizes[j]//4)
+ max_128 = torch.max(torch.sqrt(compute_sum(torch.pow(f_128_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
+ self.Dict_128[i] = f_128_reshape #/ max_128
+
+ f_64 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_64_center.npy'.format(i)), allow_pickle=True))
+
+ f_64_reshape = f_64.reshape(f_64.size(0),self.channel_sizes[2],self.part_sizes[j]//8,self.part_sizes[j]//8)
+ max_64 = torch.max(torch.sqrt(compute_sum(torch.pow(f_64_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
+ self.Dict_64[i] = f_64_reshape #/ max_64
+
+ f_32 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_32_center.npy'.format(i)), allow_pickle=True))
+
+ f_32_reshape = f_32.reshape(f_32.size(0),self.channel_sizes[3],self.part_sizes[j]//16,self.part_sizes[j]//16)
+ max_32 = torch.max(torch.sqrt(compute_sum(torch.pow(f_32_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
+ self.Dict_32[i] = f_32_reshape #/ max_32
+
+ self.le_256 = AttentionBlock(128)
+ self.le_128 = AttentionBlock(256)
+ self.le_64 = AttentionBlock(512)
+ self.le_32 = AttentionBlock(512)
+
+ self.re_256 = AttentionBlock(128)
+ self.re_128 = AttentionBlock(256)
+ self.re_64 = AttentionBlock(512)
+ self.re_32 = AttentionBlock(512)
+
+ self.no_256 = AttentionBlock(128)
+ self.no_128 = AttentionBlock(256)
+ self.no_64 = AttentionBlock(512)
+ self.no_32 = AttentionBlock(512)
+
+ self.mo_256 = AttentionBlock(128)
+ self.mo_128 = AttentionBlock(256)
+ self.mo_64 = AttentionBlock(512)
+ self.mo_32 = AttentionBlock(512)
+
+ #norm
+ self.VggExtract = VGGFeat()
+
+ ######################
+ self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1]) #
+
+ self.up0 = StyledUpBlock(ngf*8,ngf*8)
+ self.up1 = StyledUpBlock(ngf*8, ngf*4) #
+ self.up2 = StyledUpBlock(ngf*4, ngf*2) #
+ self.up3 = StyledUpBlock(ngf*2, ngf) #
+ self.up4 = nn.Sequential( # 128
+ # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+ SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+ # nn.BatchNorm2d(32),
+ nn.LeakyReLU(0.2),
+ UpResBlock(ngf),
+ UpResBlock(ngf),
+ # SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
+ nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
+ nn.Tanh()
+ )
+ self.to_rgb0 = ToRGB(ngf*8)
+ self.to_rgb1 = ToRGB(ngf*4)
+ self.to_rgb2 = ToRGB(ngf*2)
+ self.to_rgb3 = ToRGB(ngf*1)
+
+ # for param in self.BlurInputConv.parameters():
+ # param.requires_grad = False
+
+ def forward(self,input, part_locations):
+
+ VggFeatures = self.VggExtract(input)
+ # for b in range(input.size(0)):
+ b = 0
+ UpdateVggFeatures = []
+ for i, f_size in enumerate(self.feature_sizes):
+ cur_feature = VggFeatures[i]
+ update_feature = cur_feature.clone() #* 0
+ cur_part_sizes = self.part_sizes // (512/f_size)
+
+ dicts_feature = getattr(self, 'Dict_'+str(f_size))
+ LE_Dict_feature = dicts_feature['left_eye'].to(input)
+ RE_Dict_feature = dicts_feature['right_eye'].to(input)
+ NO_Dict_feature = dicts_feature['nose'].to(input)
+ MO_Dict_feature = dicts_feature['mouth'].to(input)
+
+ le_location = (part_locations[0][b] // (512/f_size)).int()
+ re_location = (part_locations[1][b] // (512/f_size)).int()
+ no_location = (part_locations[2][b] // (512/f_size)).int()
+ mo_location = (part_locations[3][b] // (512/f_size)).int()
+
+ LE_feature = cur_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]].clone()
+ RE_feature = cur_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]].clone()
+ NO_feature = cur_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]].clone()
+ MO_feature = cur_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]].clone()
+
+ #resize
+ LE_feature_resize = F.interpolate(LE_feature,(LE_Dict_feature.size(2),LE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
+ RE_feature_resize = F.interpolate(RE_feature,(RE_Dict_feature.size(2),RE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
+ NO_feature_resize = F.interpolate(NO_feature,(NO_Dict_feature.size(2),NO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
+ MO_feature_resize = F.interpolate(MO_feature,(MO_Dict_feature.size(2),MO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
+
+ LE_Dict_feature_norm = adaptive_instance_normalization_4D(LE_Dict_feature, LE_feature_resize)
+ RE_Dict_feature_norm = adaptive_instance_normalization_4D(RE_Dict_feature, RE_feature_resize)
+ NO_Dict_feature_norm = adaptive_instance_normalization_4D(NO_Dict_feature, NO_feature_resize)
+ MO_Dict_feature_norm = adaptive_instance_normalization_4D(MO_Dict_feature, MO_feature_resize)
+
+ LE_score = F.conv2d(LE_feature_resize, LE_Dict_feature_norm)
+
+ LE_score = F.softmax(LE_score.view(-1),dim=0)
+ LE_index = torch.argmax(LE_score)
+ LE_Swap_feature = F.interpolate(LE_Dict_feature_norm[LE_index:LE_index+1], (LE_feature.size(2), LE_feature.size(3)))
+
+ LE_Attention = getattr(self, 'le_'+str(f_size))(LE_Swap_feature-LE_feature)
+ LE_Att_feature = LE_Attention * LE_Swap_feature
+
+
+ RE_score = F.conv2d(RE_feature_resize, RE_Dict_feature_norm)
+ RE_score = F.softmax(RE_score.view(-1),dim=0)
+ RE_index = torch.argmax(RE_score)
+ RE_Swap_feature = F.interpolate(RE_Dict_feature_norm[RE_index:RE_index+1], (RE_feature.size(2), RE_feature.size(3)))
+
+ RE_Attention = getattr(self, 're_'+str(f_size))(RE_Swap_feature-RE_feature)
+ RE_Att_feature = RE_Attention * RE_Swap_feature
+
+ NO_score = F.conv2d(NO_feature_resize, NO_Dict_feature_norm)
+ NO_score = F.softmax(NO_score.view(-1),dim=0)
+ NO_index = torch.argmax(NO_score)
+ NO_Swap_feature = F.interpolate(NO_Dict_feature_norm[NO_index:NO_index+1], (NO_feature.size(2), NO_feature.size(3)))
+
+ NO_Attention = getattr(self, 'no_'+str(f_size))(NO_Swap_feature-NO_feature)
+ NO_Att_feature = NO_Attention * NO_Swap_feature
+
+
+ MO_score = F.conv2d(MO_feature_resize, MO_Dict_feature_norm)
+ MO_score = F.softmax(MO_score.view(-1),dim=0)
+ MO_index = torch.argmax(MO_score)
+ MO_Swap_feature = F.interpolate(MO_Dict_feature_norm[MO_index:MO_index+1], (MO_feature.size(2), MO_feature.size(3)))
+
+ MO_Attention = getattr(self, 'mo_'+str(f_size))(MO_Swap_feature-MO_feature)
+ MO_Att_feature = MO_Attention * MO_Swap_feature
+
+ update_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]] = LE_Att_feature + LE_feature
+ update_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]] = RE_Att_feature + RE_feature
+ update_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]] = NO_Att_feature + NO_feature
+ update_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]] = MO_Att_feature + MO_feature
+
+ UpdateVggFeatures.append(update_feature)
+
+ fea_vgg = self.MSDilate(VggFeatures[3])
+ #new version
+ fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
+ # out1 = F.interpolate(fea_up0,(512,512))
+ # out1 = self.to_rgb0(out1)
+
+ fea_up1 = self.up1( fea_up0, UpdateVggFeatures[2]) #
+ # out2 = F.interpolate(fea_up1,(512,512))
+ # out2 = self.to_rgb1(out2)
+
+ fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1]) #
+ # out3 = F.interpolate(fea_up2,(512,512))
+ # out3 = self.to_rgb2(out3)
+
+ fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0]) #
+ # out4 = F.interpolate(fea_up3,(512,512))
+ # out4 = self.to_rgb3(out4)
+
+ output = self.up4(fea_up3) #
+
+
+ return output #+ out4 + out3 + out2 + out1
+ # VGG feature shapes used above (channels x height x width):
+ #   0: 128 x 256 x 256
+ #   1: 256 x 128 x 128
+ #   2: 512 x 64 x 64
+ #   3: 512 x 32 x 32
+
+
+class UpResBlock(nn.Module):
+ def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
+ super(UpResBlock, self).__init__()
+ self.Model = nn.Sequential(
+ # SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+ conv_layer(dim, dim, 3, 1, 1),
+ # norm_layer(dim),
+ nn.LeakyReLU(0.2,True),
+ # SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+ conv_layer(dim, dim, 3, 1, 1),
+ )
+ def forward(self, x):
+ out = x + self.Model(x)
+ return out
+
+class VggClassNet(nn.Module):
+ def __init__(self, select_layer = ['0','5','10','19']):
+ super(VggClassNet, self).__init__()
+ self.select = select_layer
+ self.vgg = models.vgg19(pretrained=True).features
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ features = []
+ for name, layer in self.vgg._modules.items():
+ x = layer(x)
+ if name in self.select:
+ features.append(x)
+ return features
+
+
+if __name__ == '__main__':
+ print('this is network')
+
+
diff --git a/models/test_model.py b/models/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae702e5b9cc19eb36c208d72ab5e07c6c38f7eca
--- /dev/null
+++ b/models/test_model.py
@@ -0,0 +1,48 @@
+from .base_model import BaseModel
+from . import networks
+import torch
+import numpy as np
+import torchvision.transforms as transforms
+import PIL
+
+import torch.nn.functional as F
+
+class TestModel(BaseModel):
+ def name(self):
+ return 'TestModel'
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ assert not is_train, 'TestModel cannot be used in train mode'
+ parser.set_defaults(dataset_mode='aligned')
+
+ parser.add_argument('--model_suffix', type=str, default='',
+ help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will'
+ ' be loaded as the generator of TestModel')
+ return parser
+
+ def initialize(self, opt):
+ assert(not opt.isTrain)
+ BaseModel.initialize(self, opt)
+
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses
+ self.loss_names = []
+ # specify the images you want to save/display. The program will call base_model.get_current_visuals
+ self.visual_names = ['fake_A','real_A']
+ self.model_names = ['G']
+
+ self.netG = networks.define_G('UNetDictFace',self.gpu_ids)
+
+ def set_input(self, input):
+ self.real_A = input['A'].to(self.device) #degraded img
+ self.real_C = input['C'].to(self.device) #groundtruth
+ self.image_paths = input['A_paths']
+ self.Part_locations = input['Part_locations']
+
+ def forward(self):
+
+ self.fake_A = self.netG(self.real_A, self.Part_locations) #
+ # try:
+ # self.fake_A = self.netG(self.real_A, self.Part_locations)  # generated image
+ # except:
+ # self.fake_A = self.real_A
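+
+ # Illustrative input for set_input/forward; the key names come from set_input above,
+ # while the tensor shapes and variable names are assumptions for exposition:
+ #   data = {'A': degraded_img,             # e.g. a (1, 3, 512, 512) tensor
+ #           'C': ground_truth_img,         # same shape as 'A'
+ #           'A_paths': ['face_0001.png'],
+ #           'Part_locations': part_boxes}  # [left_eye, right_eye, nose, mouth] boxes
+ #   model.set_input(data)
+ #   model.test()                           # model.fake_A then holds the restored face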
diff --git a/options/__pycache__/base_options.cpython-38.pyc b/options/__pycache__/base_options.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3f35f849babac414bc3258bb84209d8febe391b
Binary files /dev/null and b/options/__pycache__/base_options.cpython-38.pyc differ
diff --git a/options/__pycache__/test_options.cpython-38.pyc b/options/__pycache__/test_options.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4164c89c22506a57a45f76d195bdf290850137a
Binary files /dev/null and b/options/__pycache__/test_options.cpython-38.pyc differ
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad99c47f7a17928280899abc138b1c9ff436e693
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,103 @@
+import argparse
+import os
+from util import util
+import torch
+import models
+import data
+
+
+class BaseOptions():
+ def __init__(self):
+ self.initialized = False
+
+ def initialize(self, parser):
+
+ parser.add_argument('--batchSize', type=int, default=2, help='input batch size')
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--name', type=str, default='facefh_dictionary', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--model', type=str, default='faceDict', help='chooses which model to use. cycle_gan, pix2pix, test')
+ parser.add_argument('--which_direction', type=str, default='BtoA', help='AtoB or BtoA')
+ parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--resize_or_crop', type=str, default='degradation', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
+ parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal|xavier|kaiming|orthogonal]')
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ # initialize parser with basic options
+ if not self.initialized:
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+
+ opt, _ = parser.parse_known_args()
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+
+ opt, _ = parser.parse_known_args() # parse again with the new defaults
+
+ # modify dataset-related parser options
+ dataset_name = opt.dataset_mode
+
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ self.parser = parser
+
+ return parser.parse_args()
+
+ def print_options(self, opt):
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, 'opt.txt')
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+
+ def parse(self):
+
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+ # self.print_options(opt)
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ opt.gpu_ids.append(id)
+ if len(opt.gpu_ids) > 0:
+ torch.cuda.set_device(opt.gpu_ids[0])
+
+ self.opt = opt
+ return self.opt
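+
+ # Illustrative parsing of --gpu_ids by the loop above:
+ #   '0'     -> opt.gpu_ids == [0]
+ #   '0,1,3' -> opt.gpu_ids == [0, 1, 3]
+ #   '-1'    -> opt.gpu_ids == []   (CPU mode; torch.cuda.set_device is not called)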
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f0eeb902981c81da164b7c1a82e6f57ef75722f
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,20 @@
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # parser.add_argument('--dataroot', type=str, default='/home/Data/AllDataImages/2018_FaceFH', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+ parser.add_argument('--dataroot', type=str, default='', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+ parser.add_argument('--phase', type=str, default='', help='train, val, test, etc')
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.set_defaults(model='test')
+
+ parser.add_argument('--test_path', type=str, default='./TestData/TestWhole', help='test images path')
+ parser.add_argument('--results_dir', type=str, default='./Results/TestWholeResults', help='saves results here.')
+ parser.add_argument('--upscale_factor', type=int, default=4, help='upscale factor for the whole input image (not for face)')
+
+
+ self.isTrain = False
+ return parser
diff --git a/packages/FFHQ_template.npy b/packages/FFHQ_template.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a3cf332405b244e0b55c93c599e7d7753cf52d83
Binary files /dev/null and b/packages/FFHQ_template.npy differ
diff --git a/packages/mmod_human_face_detector.dat b/packages/mmod_human_face_detector.dat
new file mode 100644
index 0000000000000000000000000000000000000000..f112a0a45dbda96080352c45f615d00ddd4f130c
Binary files /dev/null and b/packages/mmod_human_face_detector.dat differ
diff --git a/packages/shape_predictor_5_face_landmarks.dat b/packages/shape_predictor_5_face_landmarks.dat
new file mode 100644
index 0000000000000000000000000000000000000000..67878ed3894d929c5e03bd1ad2d931cc6d745ee7
Binary files /dev/null and b/packages/shape_predictor_5_face_landmarks.dat differ
diff --git a/sync_batchnorm/__init__.py b/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a10989fbef38485841fdbd1af72f53062a7b556d
--- /dev/null
+++ b/sync_batchnorm/__init__.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .batchnorm import patch_sync_batchnorm, convert_model
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/sync_batchnorm/batchnorm.py b/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c0950707d07102d06b5fdf64c74a347ec3186c
--- /dev/null
+++ b/sync_batchnorm/batchnorm.py
@@ -0,0 +1,394 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+import contextlib
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+
+try:
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+except ImportError:
+ ReduceAddCoalesced = Broadcast = None
+
+try:
+ from jactorch.parallel.comm import SyncMaster
+ from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
+except ImportError:
+ from .comm import SyncMaster
+ from .replicate import DataParallelWithCallback
+
+__all__ = [
+ 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
+ 'patch_sync_batchnorm', 'convert_model'
+]
+
+
+def _sum_ft(tensor):
+ """sum over the first and last dimention"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+ """add new dimensions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+ assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
+
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using same "device order" makes the ReduceAdd operation faster.
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
+
+ if hasattr(torch, 'no_grad'):
+ with torch.no_grad():
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+ else:
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
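+
+ # Note: normalization uses the biased variance (returned here as
+ # 1/sqrt(max(var, eps))), while the running statistics are updated with the
+ # unbiased variance, matching the semantics of torch.nn.BatchNorm.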
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly
+ the same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape::
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly
+ the same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape::
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly
+ the same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape::
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
+
+
+@contextlib.contextmanager
+def patch_sync_batchnorm():
+ import torch.nn as nn
+
+ backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+
+ nn.BatchNorm1d = SynchronizedBatchNorm1d
+ nn.BatchNorm2d = SynchronizedBatchNorm2d
+ nn.BatchNorm3d = SynchronizedBatchNorm3d
+
+ yield
+
+ nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
+
+
+def convert_model(module):
+ """Traverse the input module and its child recursively
+ and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
+ to SynchronizedBatchNorm*N*d
+
+ Args:
+ module: the input module needs to be convert to SyncBN model
+
+ Examples:
+ >>> import torch.nn as nn
+ >>> import torchvision
+ >>> # m is a standard pytorch model
+ >>> m = torchvision.models.resnet18(True)
+ >>> m = nn.DataParallel(m)
+ >>> # after convert, m is using SyncBN
+ >>> m = convert_model(m)
+ """
+ if isinstance(module, torch.nn.DataParallel):
+ mod = module.module
+ mod = convert_model(mod)
+ mod = DataParallelWithCallback(mod)
+ return mod
+
+ mod = module
+ for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
+ torch.nn.modules.batchnorm.BatchNorm2d,
+ torch.nn.modules.batchnorm.BatchNorm3d],
+ [SynchronizedBatchNorm1d,
+ SynchronizedBatchNorm2d,
+ SynchronizedBatchNorm3d]):
+ if isinstance(module, pth_module):
+ mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
+ mod.running_mean = module.running_mean
+ mod.running_var = module.running_var
+ if module.affine:
+ mod.weight.data = module.weight.data.clone().detach()
+ mod.bias.data = module.bias.data.clone().detach()
+
+ for name, child in module.named_children():
+ mod.add_module(name, convert_model(child))
+
+ return mod
diff --git a/sync_batchnorm/batchnorm_reimpl.py b/sync_batchnorm/batchnorm_reimpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0
--- /dev/null
+++ b/sync_batchnorm/batchnorm_reimpl.py
@@ -0,0 +1,74 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# File : batchnorm_reimpl.py
+# Author : acgtyrant
+# Date : 11/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+__all__ = ['BatchNorm2dReimpl']
+
+
+class BatchNorm2dReimpl(nn.Module):
+ """
+ A re-implementation of batch normalization, used for testing the numerical
+ stability.
+
+ Author: acgtyrant
+ See also:
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
+ """
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
+ super().__init__()
+
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.weight = nn.Parameter(torch.empty(num_features))
+ self.bias = nn.Parameter(torch.empty(num_features))
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ init.uniform_(self.weight)
+ init.zeros_(self.bias)
+
+ def forward(self, input_):
+ batchsize, channels, height, width = input_.size()
+ numel = batchsize * height * width
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
+ sum_ = input_.sum(1)
+ sum_of_square = input_.pow(2).sum(1)
+ mean = sum_ / numel
+ sumvar = sum_of_square - sum_ * mean
+
+ self.running_mean = (
+ (1 - self.momentum) * self.running_mean
+ + self.momentum * mean.detach()
+ )
+ unbias_var = sumvar / (numel - 1)
+ self.running_var = (
+ (1 - self.momentum) * self.running_var
+ + self.momentum * unbias_var.detach()
+ )
+
+ bias_var = sumvar / numel
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
+ output = (
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
+
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
+
diff --git a/sync_batchnorm/comm.py b/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b
--- /dev/null
+++ b/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+ - During replication, since data parallel triggers a callback on each module, every slave device should
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
+ collected and passed to the registered callback.
+ - After receiving the messages, the master device gathers the information and determines the message to be
+ passed back to each slave device.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def __getstate__(self):
+ return {'master_callback': self._master_callback}
+
+ def __setstate__(self, state):
+ self.__init__(state['master_callback'])
+
+ def register_slave(self, identifier):
+ """
+ Register a slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+ Messages are first collected from each device (including the master device), and then
+ the registered callback is invoked to compute the message to be sent back to each device
+ (including the master device).
+
+ Args:
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+ assert results[0][0] == 0, 'The first result should belong to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
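+
+
+# Illustrative sketch, not part of the upstream module: how a master thread and two slave
+# threads might drive SyncMaster.  The callback below simply sums all messages and sends the
+# total back to every device; the identifiers 1 and 2 and the integer messages are arbitrary.
+if __name__ == '__main__':
+    def _sum_callback(intermediates):
+        total = sum(msg for _, msg in intermediates)
+        return [(identifier, total) for identifier, _ in intermediates]
+
+    master = SyncMaster(_sum_callback)
+    pipes = {i: master.register_slave(i) for i in (1, 2)}
+
+    slave_results = []
+    threads = [threading.Thread(target=lambda i=i: slave_results.append((i, pipes[i].run_slave(i))))
+               for i in (1, 2)]
+    for t in threads:
+        t.start()
+    print('master received:', master.run_master(0))  # 0 + 1 + 2 = 3
+    for t in threads:
+        t.join()
+    print('slaves received:', sorted(slave_results))  # [(1, 3), (2, 3)]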
diff --git a/sync_batchnorm/replicate.py b/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06
--- /dev/null
+++ b/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all copies are isomorphic, we assign each sub-module a context
+ (shared among the copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called before the callbacks
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+ A replication callback `__data_parallel_replicate__` is invoked on each module after the replicas are created by
+ the original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object to add the replication callback.
+ Useful when you have a customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
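+
+
+# Illustrative sketch, not part of the upstream module: a hypothetical module implementing the
+# `__data_parallel_replicate__` hook.  `execute_replication_callbacks` is called directly on two
+# plain copies (stand-ins for per-device replicas), showing that every copy receives its copy id
+# and that all copies share the same context object.
+if __name__ == '__main__':
+    import torch.nn as nn
+
+    class RecordsCopyId(nn.Module):
+        def __data_parallel_replicate__(self, ctx, copy_id):
+            self.copy_id = copy_id   # 0 marks the master copy
+            self.sync_ctx = ctx      # shared among all copies of this sub-module
+
+        def forward(self, x):
+            return x
+
+    copies = [RecordsCopyId(), RecordsCopyId()]
+    execute_replication_callbacks(copies)
+    print([m.copy_id for m in copies])               # [0, 1]
+    print(copies[0].sync_ctx is copies[1].sync_ctx)  # True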
diff --git a/sync_batchnorm/unittest.py b/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..bed56f1caa929ac3e9a57c583f8d3e42624f58be
--- /dev/null
+++ b/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+import torch
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, x, y):
+ adiff = float((x - y).abs().max())
+ if (y == 0).all():
+ rdiff = 'NaN'
+ else:
+ rdiff = float((adiff / y).abs().max())
+
+ message = (
+ 'Tensor close check failed\n'
+ 'adiff={}\n'
+ 'rdiff={}\n'
+ ).format(adiff, rdiff)
+ self.assertTrue(torch.allclose(x, y), message)
+
diff --git a/util/Loss.py b/util/Loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f3d9bc582b79e34ab30e8f3ed03f4cf2bb8ccad
--- /dev/null
+++ b/util/Loss.py
@@ -0,0 +1,41 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+
+class TVLoss(nn.Module):
+ def __init__(self,TVLoss_weight=1):
+ super(TVLoss,self).__init__()
+ self.TVLoss_weight = TVLoss_weight
+
+ def forward(self,x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self._tensor_size(x[:,:,1:,:])
+ count_w = self._tensor_size(x[:,:,:,1:])
+ h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
+ w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
+ return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
+
+ def _tensor_size(self,t):
+ return t.size()[1]*t.size()[2]*t.size()[3]
+
+
+class hinge_loss(nn.Module):
+ def __init__(self):
+ super(hinge_loss, self).__init__()
+
+ def forward(self, dis_fake, dis_real):
+ loss_real = torch.mean(F.relu(1. - dis_real))
+ loss_fake = torch.mean(F.relu(1. + dis_fake))
+ return loss_real + loss_fake
+
+
+class hinge_loss_G(nn.Module):
+ def __init__(self):
+ super(hinge_loss_G, self).__init__()
+
+ def forward(self, dis_fake):
+ loss_fake = -torch.mean(dis_fake)
+ return loss_fake
\ No newline at end of file
diff --git a/util/ROI_MLS.py b/util/ROI_MLS.py
new file mode 100644
index 0000000000000000000000000000000000000000..0987909e797831c3b350150a4706016b16159bc3
--- /dev/null
+++ b/util/ROI_MLS.py
@@ -0,0 +1,542 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+#mls affine inv
+
+def mls_affine_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
+ ''' Affine inverse deformation
+ ### Params:
+ * image - ndarray: original image
+ * p - ndarray: an array with size [n, 2], original control points
+ * q - ndarray: an array with size [n, 2], final control points
+ * alpha - float: parameter used by weights
+ * density - float: density of the grids
+ ### Return:
+ The deformation grid and the deformed image.
+ '''
+ # height = image.shape[0]
+ # width = image.shape[1]
+ # Change (x, y) to (row, col)
+ q = q[:, [1, 0]]
+ p = p[:, [1, 0]]
+
+ # Make grids on the original image
+ gridX = np.linspace(0, width, num=int(width*density), endpoint=False)
+ gridY = np.linspace(0, height, num=int(height*density), endpoint=False)
+ vy, vx = np.meshgrid(gridX, gridY)
+ grow = vx.shape[0] # grid rows
+ gcol = vx.shape[1] # grid cols
+ ctrls = p.shape[0] # control points
+
+ # Compute
+ reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1]
+ reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1]
+ reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol]
+
+ w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol]
+ w[w == np.inf] = 2**31 - 1
+ pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ phat = reshaped_p - pstar # [ctrls, 2, grow, gcol]
+ qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol]
+
+ reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol]
+ reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol]
+ pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0) # [2, 2, grow, gcol]
+ try:
+ inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2]
+ flag = False
+ except np.linalg.linalg.LinAlgError:
+ flag = True
+ det = np.linalg.det(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol]
+ det[det < 1e-8] = np.inf
+ reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol]
+ adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol]
+ adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol]
+ inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2]
+ mul_left = reshaped_v - qstar # [2, grow, gcol]
+ reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2]
+ mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol]
+ reshaped_mul_right = mul_right.transpose(2, 3, 0, 1) # [grow, gcol, 2, 2]
+ temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right) # [grow, gcol, 1, 2]
+ reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol]
+
+ # Get final image transformer -- 3-D array
+ transformers = reshaped_temp + pstar # [2, grow, gcol]
+
+ # Correct the points where pTwp is singular
+ if flag:
+ blidx = det == np.inf # bool index
+ transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
+ transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]
+
+ # Remove the points outside the border
+ transformers[transformers < 0] = 0
+ transformers[0][transformers[0] > height - 1] = 0
+ transformers[1][transformers[1] > width - 1] = 0
+
+ # Mapping original image
+ transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol]
+
+ # Rescale image
+ # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect')
+
+ return transformers.astype(np.float64), transformed_image
+
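+# Illustrative sketch, not part of the original utilities: warping a random RGB image by
+# moving three control points with the affine MLS deformation above.  All coordinates and
+# sizes below are arbitrary and only serve to show the expected shapes.
+if __name__ == '__main__':
+    rng = np.random.RandomState(0)
+    img = rng.rand(64, 64, 3)
+    p = np.array([[16, 16], [48, 16], [32, 48]], dtype=np.float64)  # original control points (x, y)
+    q = np.array([[20, 18], [44, 14], [32, 52]], dtype=np.float64)  # target control points (x, y)
+    grid, warped = mls_affine_deformation_inv(img, 64, 64, 3, p, q, density=1.0)
+    print(grid.shape, warped.shape)  # (2, 64, 64) (64, 64, 3)
+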
+
+def mls_affine_deformation_inv_final(height, width, channel, p, q, alpha=1.0, density=1.0):
+ ''' Affine inverse deformation
+ ### Params:
+ * height, width, channel - int: size of the target feature map
+ * p - ndarray: an array with size [n, 2], original control points
+ * q - ndarray: an array with size [n, 2], final control points
+ * alpha - float: parameter used by weights
+ * density - float: density of the grids
+ ### Return:
+ The inverse deformation grid (an array with shape [2, grow, gcol]).
+ '''
+ # height = image.shape[0]
+ # width = image.shape[1]
+ # Change (x, y) to (row, col)
+ q = q[:, [1, 0]]
+ p = p[:, [1, 0]]
+
+ # Make grids on the original image
+ gridX = np.linspace(0, width, num=int(width*density), endpoint=False)
+ gridY = np.linspace(0, height, num=int(height*density), endpoint=False)
+ vy, vx = np.meshgrid(gridX, gridY)
+ grow = vx.shape[0] # grid rows
+ gcol = vx.shape[1] # grid cols
+ ctrls = p.shape[0] # control points
+
+ # Compute
+ reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1]
+ reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1]
+ reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol]
+
+ w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol]
+ w[w == np.inf] = 2**31 - 1
+ pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ phat = reshaped_p - pstar # [ctrls, 2, grow, gcol]
+ qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol]
+
+ reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol]
+ reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol]
+ pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0) # [2, 2, grow, gcol]
+ try:
+ inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2]
+ flag = False
+ except np.linalg.linalg.LinAlgError:
+ flag = True
+ det = np.linalg.det(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol]
+ det[det < 1e-8] = np.inf
+ reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol]
+ adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol]
+ adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol]
+ inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2]
+ mul_left = reshaped_v - qstar # [2, grow, gcol]
+ reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2]
+ mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol]
+ reshaped_mul_right = mul_right.transpose(2, 3, 0, 1) # [grow, gcol, 2, 2]
+ temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right) # [grow, gcol, 1, 2]
+ reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol]
+
+ # Get final image transformer -- 3-D array
+ transformers = reshaped_temp + pstar # [2, grow, gcol]
+
+ # Correct the points where pTwp is singular
+ if flag:
+ blidx = det == np.inf # bool index
+ transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
+ transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]
+
+ # Remove the points outside the border
+ transformers[transformers < 0] = 0
+ transformers[0][transformers[0] > height - 1] = 0
+ transformers[1][transformers[1] > width - 1] = 0
+
+ return transformers
+
+def mls_similarity_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
+ ''' Similarity inverse deformation
+ ### Params:
+ * image - ndarray: original image
+ * p - ndarray: an array with size [n, 2], original control points
+ * q - ndarray: an array with size [n, 2], final control points
+ * alpha - float: parameter used by weights
+ * density - float: density of the grids
+ ### Return:
+ The deformation grid and the deformed image.
+ '''
+ height = image.shape[0]
+ width = image.shape[1]
+ # Change (x, y) to (row, col)
+ q = q[:, [1, 0]]
+ p = p[:, [1, 0]]
+
+ # Make grids on the original image
+ gridX = np.linspace(0, width, num=int(width*density), endpoint=False)
+ gridY = np.linspace(0, height, num=int(height*density), endpoint=False)
+ vy, vx = np.meshgrid(gridX, gridY)
+ grow = vx.shape[0] # grid rows
+ gcol = vx.shape[1] # grid cols
+ ctrls = p.shape[0] # control points
+
+ # Compute
+ reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1]
+ reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1]
+ reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol]
+
+ w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol]
+ w[w == np.inf] = 2**31 - 1
+ pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ phat = reshaped_p - pstar # [ctrls, 2, grow, gcol]
+ qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol]
+ reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol]
+ reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol]
+
+ mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) *
+ reshaped_phat1.transpose(0, 3, 4, 1, 2),
+ reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1]
+ reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol]
+ neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol]
+ neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...]
+ reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol]
+ mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol]
+ Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2),
+ mul_right.transpose(0, 3, 4, 1, 2)),
+ axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1]
+ Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1]
+ Delta_verti[...,0,:] = -Delta_verti[...,0,:]
+ B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2]
+ try:
+ inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2]
+ flag = False
+ except np.linalg.linalg.LinAlgError:
+ flag = True
+ det = np.linalg.det(B) # [grow, gcol]
+ det[det < 1e-8] = np.inf
+ reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1]
+ adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2]
+ adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2]
+ inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol]
+
+ v_minus_qstar_mul_mu = (reshaped_v - qstar) * reshaped_mu # [2, grow, gcol]
+
+ # Get final image transformer -- 3-D array
+ reshaped_v_minus_qstar_mul_mu = v_minus_qstar_mul_mu.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol]
+ transformers = np.matmul(reshaped_v_minus_qstar_mul_mu.transpose(2, 3, 0, 1),
+ inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) + pstar # [2, grow, gcol]
+
+ # Correct the points where pTwp is singular
+ if flag:
+ blidx = det == np.inf # bool index
+ transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
+ transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]
+
+ # Remove the points outside the border
+ transformers[transformers < 0] = 0
+ transformers[0][transformers[0] > height - 1] = 0
+ transformers[1][transformers[1] > width - 1] = 0
+
+ # Mapping original image
+ transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol]
+
+ # Rescale image
+ # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect')
+
+ return transformers, transformed_image
+
+
+def mls_rigid_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0):
+ ''' Rigid inverse deformation
+ ### Params:
+ * image - ndarray: original image
+ * p - ndarray: an array with size [n, 2], original control points
+ * q - ndarray: an array with size [n, 2], final control points
+ * alpha - float: parameter used by weights
+ * density - float: density of the grids
+ ### Return:
+ The deformation grid and the deformed image.
+ '''
+ height = image.shape[0]
+ width = image.shape[1]
+ # Change (x, y) to (row, col)
+ q = q[:, [1, 0]]
+ p = p[:, [1, 0]]
+
+ # Make grids on the original image
+ gridX = np.linspace(0, width, num=int(width*density), endpoint=False)
+ gridY = np.linspace(0, height, num=int(height*density), endpoint=False)
+ vy, vx = np.meshgrid(gridX, gridY)
+ grow = vx.shape[0] # grid rows
+ gcol = vx.shape[1] # grid cols
+ ctrls = p.shape[0] # control points
+
+ # Compute
+ reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1]
+ reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1]
+ reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol]
+
+ w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol]
+ w[w == np.inf] = 2**31 - 1
+ pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ phat = reshaped_p - pstar # [ctrls, 2, grow, gcol]
+ qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol]
+ qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol]
+ reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol]
+ reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol]
+
+ mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) *
+ reshaped_phat1.transpose(0, 3, 4, 1, 2),
+ reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1]
+ reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol]
+ neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol]
+ neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...]
+ reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol]
+ mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol]
+ Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2),
+ mul_right.transpose(0, 3, 4, 1, 2)),
+ axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1]
+ Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1]
+ Delta_verti[...,0,:] = -Delta_verti[...,0,:]
+ B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2]
+ try:
+ inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2]
+ flag = False
+ except np.linalg.linalg.LinAlgError:
+ flag = True
+ det = np.linalg.det(B) # [grow, gcol]
+ det[det < 1e-8] = np.inf
+ reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1]
+ adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2]
+ adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2]
+ inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol]
+
+ vqstar = reshaped_v - qstar # [2, grow, gcol]
+ reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol]
+
+ # Get final image transformer -- 3-D array
+ temp = np.matmul(reshaped_vqstar.transpose(2, 3, 0, 1),
+ inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol]
+ norm_temp = np.linalg.norm(temp, axis=0, keepdims=True) # [1, grow, gcol]
+ norm_vqstar = np.linalg.norm(vqstar, axis=0, keepdims=True) # [1, grow, gcol]
+ transformers = temp / norm_temp * norm_vqstar + pstar # [2, grow, gcol]
+
+ # Correct the points where pTwp is singular
+ if flag:
+ blidx = det == np.inf # bool index
+ transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
+ transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]
+
+ # Remove the points outside the border
+ transformers[transformers < 0] = 0
+ transformers[0][transformers[0] > height - 1] = 0
+ transformers[1][transformers[1] > width - 1] = 0
+
+ # Mapping original image
+ transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol]
+
+ # Rescale image
+ # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect')
+
+ return transformers, transformed_image
+
+# mls rigid algorithm
+def mls_rigid_deformation_inv_wy(image, height, width, channel, p, q, alpha=1.0, density=1.0):
+ '''
+ Rigid inverse deformation
+ ### Params:
+ * image - ndarray: original image
+ * p - ndarray: an array with size [n, 2], original control points
+ * q - ndarray: an array with size [n, 2], final control points
+ * alpha - float: parameter used by weights
+ * density - float: density of the grids
+ ### Return:
+ The deformation grid and the deformed image.
+ '''
+ # Change (x, y) to (row, col)
+ q = q[:, [1, 0]]
+ p = p[:, [1, 0]]
+
+ # Make grids on the original image
+ gridX = torch.linspace(0, width, steps=int(width * density))
+ gridY = torch.linspace(0, height, steps=int(height * density))
+ vx, vy = torch.meshgrid(gridY, gridX)
+
+ grow = vx.shape[0] # grid rows
+ gcol = vx.shape[1] # grid cols
+ ctrls = p.shape[0] # control points
+
+ # Compute
+ reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1]
+ reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1]
+ reshaped_v = torch.stack((vx, vy), dim=0) # [2, grow, gcol]
+
+ w = 1.0 / torch.sum((reshaped_p - reshaped_v) ** 2, dim=1) ** alpha # [ctrls, grow, gcol]
+ w[w == torch.tensor(float("inf"))] = 2 ** 31 - 1
+ pstar = torch.sum(w * reshaped_p.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0) # [2, grow, gcol]
+ phat = reshaped_p - pstar # [ctrls, 2, grow, gcol]
+ qstar = torch.sum(w * reshaped_q.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0) # [2, grow, gcol]
+ qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol]
+ reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol]
+
+ neg_phat_verti = phat[:, [1, 0], ...] # [ctrls, 2, grow, gcol]
+ neg_phat_verti[:, 1, ...] = -neg_phat_verti[:, 1, ...]
+ reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol]
+ mul_right = torch.cat((reshaped_phat1, reshaped_neg_phat_verti), dim=1) # [ctrls, 2, 2, grow, gcol]
+ mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol]
+ Delta = torch.sum(torch.matmul(mul_left.permute(0, 3, 4, 1, 2),
+ mul_right.permute(0, 3, 4, 1, 2)),
+ dim=0).permute(0, 1, 3, 2) # [grow, gcol, 2, 1]
+ Delta_verti = Delta[..., [1, 0], :] # [grow, gcol, 2, 1]
+ Delta_verti[..., 0, :] = -Delta_verti[..., 0, :]
+ B = torch.cat((Delta, Delta_verti), dim=3) # [grow, gcol, 2, 2]
+ try:
+ inv_B = torch.inverse(B) # [grow, gcol, 2, 2]
+ flag = False
+ except RuntimeError:
+ flag = True
+ det = np.linalg.det(B.numpy())
+ det = torch.from_numpy(det) # [grow, gcol]
+ det[det < 1e-8] = torch.tensor(float("inf"))
+ reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1]
+ adjoint = B[:, :, [[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2]
+ adjoint[:, :, [0, 1], [1, 0]] = -adjoint[:, :, [0, 1], [1, 0]] # [grow, gcol, 2, 2]
+ inv_B = (adjoint / reshaped_det).permute(2, 3, 0, 1) # [2, 2, grow, gcol]
+
+ vqstar = reshaped_v - qstar # [2, grow, gcol]
+ reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol]
+
+ # Get final image transformer -- 3-D array
+ temp = torch.matmul(reshaped_vqstar.permute(2, 3, 0, 1),
+ inv_B).reshape(grow, gcol, 2).permute(2, 0, 1) # [2, grow, gcol]
+ norm_temp = torch.norm(temp, dim=0, keepdim=True) # [1, grow, gcol]
+ norm_vqstar = torch.norm(vqstar, dim=0, keepdim=True) # [1, grow, gcol]
+ transformers = temp / (norm_temp + 1e-10) * norm_vqstar + pstar # [2, grow, gcol]
+
+ # Correct the points where pTwp is singular
+ if flag:
+ blidx = det == torch.tensor(float("inf")) # bool index
+ transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx]
+ transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx]
+
+ # Remove the points outside the border
+ transformers[transformers < 0] = 0
+ transformers[0][transformers[0] > height - 1] = 0
+ transformers[1][transformers[1] > width - 1] = 0
+
+ # Mapping original image
+
+ transformed_image = image[tuple(transformers.numpy().astype(np.int16))] # [grow, gcol]
+
+ # # Rescale image
+ # img_h, img_w, _ = transformed_image.shape
+ # transformed_image = transformed_image.resize_(int(img_h/density), int(img_w/density), channel)
+
+ return transformers.numpy(), transformed_image
+
+
+
+
+
+
+
+
+# mls for whole feature, instead of roi align
+def roi_mls_whole(feature, d_point, g_point, step=1):
+ '''
+ :param feature: input guidance feature [C, H, W]
+ :param d_point: landmark for degraded feature [N, 2]
+ :param g_point: landmark for guidance feature [N, 2]
+ :param step: stride used when sub-sampling landmarks; number of control points = landmark_number / step
+ :return: transformed feature [C, H, W]
+ '''
+ # feature 3 * 256 * 256
+
+ channel = feature.size(0)
+ height = feature.size(1) # 256 * 256
+ width = feature.size(2)
+
+ # sub-sample the landmarks (every step-th point is used as a control point)
+ g_land = g_point[0::step, :]
+ d_land = d_point[0::step, :]
+
+ # mls
+
+ featureTmp = feature.permute(1,2,0)
+ # grid, timg = mls_rigid_deformation_inv_wy(featureTmp.cpu(), height, width, channel, g_land.cpu(), d_land.cpu(), density=1.) # 2 * 256 * 256 # wenyu
+ grid, timg = mls_affine_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #affine prefered
+ # grid, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) # similarity
+ # grid, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #rigid
+
+
+ grid = (grid - 127.5) / 127.5
+ gridNew = torch.from_numpy(grid[[1,0],:,:]).float().permute(1,2,0).unsqueeze(0)
+
+ if torch.cuda.is_available():
+ gridNew = gridNew.cuda()
+ featureNew = feature.unsqueeze(0)
+ # warp_feature = F.grid_sample(featureNew,gridNew.cuda(),mode='nearest')
+ warp_feature = F.grid_sample(featureNew, gridNew)
+
+ # HWC -> CHW
+
+
+ # _, timg_affine = mls_affine_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.)
+ # _, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.)
+ # _, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.)
+
+ # return warp_feature.squeeze(), timg.permute(2,0,1) #, timg_sim.permute(2,0,1), timg_rigid.permute(2,0,1)
+ return warp_feature.squeeze(), gridNew
+
+
+def roi_mls_whole_final(feature, d_point, g_point, step=1):
+ '''
+ :param feature: input guidance feature [C, H, W]
+ :param d_point: landmark for degraded feature [N, 2]
+ :param g_point: landmark for guidance feature [N, 2]
+ :param step: stride used when sub-sampling landmarks; number of control points = landmark_number / step
+ :return: transformed feature [C, H, W]
+ '''
+ # feature 3 * 256 * 256
+
+ channel = feature.size(0)
+ height = feature.size(1) # 256 * 256
+ width = feature.size(2)
+
+ # sub-sample the landmarks (every step-th point is used as a control point)
+ g_land = g_point[0::step, :]
+ d_land = d_point[0::step, :]
+
+ # mls
+
+ # featureTmp = feature.permute(1,2,0)
+ # grid, timg = mls_rigid_deformation_inv_wy(featureTmp.cpu(), height, width, channel, g_land.cpu(), d_land.cpu(), density=1.) # 2 * 256 * 256 # wenyu
+ grid = mls_affine_deformation_inv_final(height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #affine prefered
+ # grid, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) # similarity
+ # grid, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #rigid
+
+ grid = (grid - height/2) / (height/2)
+
+ gridNew = torch.from_numpy(grid[[1,0],:,:]).float().permute(1,2,0).unsqueeze(0)
+
+ if torch.cuda.is_available():
+ gridNew = gridNew.cuda()
+ # HWC -> CHW
+ return gridNew
\ No newline at end of file
diff --git a/util/get_data.py b/util/get_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6325605bc68ec3b4036a4e0f42c28e0b8965867d
--- /dev/null
+++ b/util/get_data.py
@@ -0,0 +1,115 @@
+from __future__ import print_function
+import os
+import tarfile
+import requests
+from warnings import warn
+from zipfile import ZipFile
+from bs4 import BeautifulSoup
+from os.path import abspath, isdir, join, basename
+
+
+class GetData(object):
+ """
+
+ Download CycleGAN or Pix2Pix Data.
+
+ Args:
+ technique : str
+ One of: 'cyclegan' or 'pix2pix'.
+ verbose : bool
+ If True, print additional information.
+
+ Examples:
+ >>> from util.get_data import GetData
+ >>> gd = GetData(technique='cyclegan')
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
+
+ """
+
+ def __init__(self, technique='cyclegan', verbose=True):
+ url_dict = {
+ 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
+ }
+ self.url = url_dict.get(technique.lower())
+ self._verbose = verbose
+
+ def _print(self, text):
+ if self._verbose:
+ print(text)
+
+ @staticmethod
+ def _get_options(r):
+ soup = BeautifulSoup(r.text, 'lxml')
+ options = [h.text for h in soup.find_all('a', href=True)
+ if h.text.endswith(('.zip', 'tar.gz'))]
+ return options
+
+ def _present_options(self):
+ r = requests.get(self.url)
+ options = self._get_options(r)
+ print('Options:\n')
+ for i, o in enumerate(options):
+ print("{0}: {1}".format(i, o))
+ choice = input("\nPlease enter the number of the "
+ "dataset above you wish to download:")
+ return options[int(choice)]
+
+ def _download_data(self, dataset_url, save_path):
+ if not isdir(save_path):
+ os.makedirs(save_path)
+
+ base = basename(dataset_url)
+ temp_save_path = join(save_path, base)
+
+ with open(temp_save_path, "wb") as f:
+ r = requests.get(dataset_url)
+ f.write(r.content)
+
+ if base.endswith('.tar.gz'):
+ obj = tarfile.open(temp_save_path)
+ elif base.endswith('.zip'):
+ obj = ZipFile(temp_save_path, 'r')
+ else:
+ raise ValueError("Unknown File Type: {0}.".format(base))
+
+ self._print("Unpacking Data...")
+ obj.extractall(save_path)
+ obj.close()
+ os.remove(temp_save_path)
+
+ def get(self, save_path, dataset=None):
+ """
+
+ Download a dataset.
+
+ Args:
+ save_path : str
+ A directory to save the data to.
+ dataset : str, optional
+ A specific dataset to download.
+ Note: this must include the file extension.
+ If None, options will be presented for you
+ to choose from.
+
+ Returns:
+ save_path_full : str
+ The absolute path to the downloaded data.
+
+ """
+ if dataset is None:
+ selected_dataset = self._present_options()
+ else:
+ selected_dataset = dataset
+
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
+
+ if isdir(save_path_full):
+ warn("\n'{0}' already exists. Voiding Download.".format(
+ save_path_full))
+ else:
+ self._print('Downloading Data...')
+ url = "{0}/{1}".format(self.url, selected_dataset)
+ self._download_data(url, save_path=save_path)
+
+ return abspath(save_path_full)
diff --git a/util/html.py b/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7956f1353fd25aee253e39a6178481b0b330621
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,64 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+ def __init__(self, web_dir, title, reflesh=0):
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+ # print(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ if reflesh > 0:
+ with self.doc.head:
+ meta(http_equiv="reflesh", content=str(reflesh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+ def add_header(self, str):
+ with self.doc:
+ h3(str)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=400):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..52413e0f8a45a8c8511bf103d3aabd537fac97b9
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,32 @@
+import random
+import torch
+
+
+class ImagePool():
+ """Buffer of previously generated images.
+
+ query() returns images for the discriminator: while the pool is not full, incoming
+ images are stored and returned unchanged; once it is full, each incoming image is,
+ with probability 0.5, swapped with a randomly chosen stored image (the new image
+ takes its place in the pool), and otherwise returned as-is.
+ """
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images:
+ image = torch.unsqueeze(image.data, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = torch.cat(return_images, 0)
+ return return_images
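+
+
+# Illustrative sketch, not part of the training code: feeding freshly generated fake images
+# through the pool before they reach a discriminator.  The pool size and tensor shape below
+# are arbitrary.
+if __name__ == '__main__':
+    pool = ImagePool(pool_size=50)
+    fake_images = torch.randn(4, 3, 16, 16)   # a batch of generator outputs
+    for_discriminator = pool.query(fake_images)
+    print(for_discriminator.shape)             # torch.Size([4, 3, 16, 16])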
diff --git a/util/util.py b/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7206cea2d9188bc292a5f640096c0a01ec682b42
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+import torchvision
+
+# Converts a Tensor into an image array (numpy)
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(input_image, norm=1, imtype=np.uint8):
+ if isinstance(input_image, torch.Tensor):
+ image_tensor = input_image.data
+ else:
+ return input_image
+ if norm == 1: #for clamp -1 to 1
+ image_numpy = image_tensor[0].cpu().float().clamp_(-1,1).numpy()
+ elif norm == 2: # for norm through max-min
+ image_ = image_tensor[0].cpu().float()
+ max_ = torch.max(image_)
+ min_ = torch.min(image_)
+ image_numpy = (image_ - min_)/(max_-min_)*2-1
+ image_numpy = image_numpy.numpy()
+ else:
+ raise ValueError('tensor2im: unknown norm mode %s' % norm)
+ if image_numpy.shape[0] == 1:
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ # print(image_numpy.shape)
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ # print(image_numpy.shape)
+ return image_numpy.astype(imtype)
+def tensor2im3Channels(input_image, imtype=np.uint8):
+ if isinstance(input_image, torch.Tensor):
+ image_tensor = input_image.data
+ else:
+ return input_image
+
+ image_numpy = image_tensor.cpu().float().clamp_(-1,1).numpy()
+
+ # print(image_numpy.shape)
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ # print(image_numpy.shape)
+ return image_numpy.astype(imtype)
+
+def diagnose_network(net, name='network'):
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+
+
+def print_numpy(x, val=True, shp=False):
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def print_current_losses(epoch, i, losses, t, t_data):
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+ # with open('', "a") as log_file:
+ # log_file.write('%s\n' % message)
+
+def display_current_results(writer,visuals,losses,step,save_result):
+ for label, images in visuals.items():
+ if 'Mask' in label:# or 'Scale' in label:
+ grid = torchvision.utils.make_grid(images,normalize=False, scale_each=True)
+ # pass
+ else:
+ pass
+ grid = torchvision.utils.make_grid(images,normalize=True, scale_each=True)
+ writer.add_image(label,grid,step)
+ for k,v in losses.items():
+ writer.add_scalar(k,v,step)
+
+def VisualFeature(input_feature, imtype=np.uint8):
+ if isinstance(input_feature, torch.Tensor):
+ image_tensor = input_feature.data
+ else:
+ return input_feature
+
+ image_ = image_tensor.cpu().float()
+
+ if image_.size(1) == 3:
+ image_ = image_.permute(1,2,0)
+
+ # assert(image_.size(1) == 1)
+
+
+
+ #####norm 0 to 1
+ max_ = torch.max(image_)
+ min_ = torch.min(image_)
+ image_numpy = (image_ - min_)/(max_-min_)*2-1
+ image_numpy = image_numpy.numpy()
+ image_numpy = (image_numpy + 1) / 2.0 * 255.0
+ #####no norm
+ # print((max_,min_))
+ # image_numpy = image_.numpy()
+ # image_numpy = image_numpy*255.0
+
+ return image_numpy.astype(imtype)
+
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
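+
+
+# Illustrative sketch, not part of the original helpers: converting a batched tensor in the
+# usual [-1, 1] output range to an H x W x 3 uint8 array and writing it to disk.  The tensor
+# values and the file name below are arbitrary.
+if __name__ == '__main__':
+    fake_output = torch.rand(1, 3, 8, 8) * 2 - 1   # pretend network output in [-1, 1]
+    array = tensor2im(fake_output)                  # norm=1 clamps to [-1, 1] before scaling
+    print(array.shape, array.dtype)                 # (8, 8, 3) uint8
+    save_image(array, 'example.png')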
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..77962a8750eb943e2d5c3d0052a02586d7115e08
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,242 @@
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+# from scipy.misc import imresize
+
+
+# save image to the disk
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ image_dir = webpage.get_image_dir()
+ if isinstance(image_path,list):
+ image_path = image_path[0]
+ short_path = ntpath.basename(image_path)
+
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ # image_name = '%s_%s.png' % (name, label)
+ # save_path = os.path.join(image_dir, image_name)
+ image_name = '%s.png' % (name)
+
+ util.mkdirs(os.path.join(image_dir, label))
+ save_path = os.path.join(image_dir, label, image_name)
+
+ h, w, _ = im.shape
+# if aspect_ratio > 1.0:
+# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
+# if aspect_ratio < 1.0:
+# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
+
+ util.save_image(im, save_path)
+
+ link_name = os.path.join(label,image_name)
+ ims.append(link_name)
+ txts.append(label)
+ links.append(link_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+def save_crop(visuals, save_path):
+ im_data = visuals['fake_A']
+ im = util.tensor2im(im_data)
+ util.save_image(im, save_path)
+
+# save image to the disk
+def save_images_test(webpage, visuals, image_dir, image_path, aspect_ratio=1.0, width=256):
+ # image_dir = webpage.get_image_dir()
+ if isinstance(image_path,list):
+ image_path = image_path[0]
+ short_path = ntpath.basename(image_path)
+
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ # image_name = '%s_%s.png' % (name, label)
+ # save_path = os.path.join(image_dir, image_name)
+ image_name = '%s.png' % (name)
+
+ util.mkdirs(os.path.join(image_dir, label))
+ save_path = os.path.join(image_dir, label, image_name)
+
+
+ h, w, _ = im.shape
+# if aspect_ratio > 1.0:
+# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
+# if aspect_ratio < 1.0:
+# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
+
+ util.save_image(im, save_path)
+
+ link_name = os.path.join(label,image_name)
+ ims.append(link_name)
+ txts.append(label)
+ links.append(link_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+# save image to the disk
+# Original Version
+# def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+# image_dir = webpage.get_image_dir()
+# short_path = ntpath.basename(image_path[0])
+# name = os.path.splitext(short_path)[0]
+
+# webpage.add_header(name)
+# ims, txts, links = [], [], []
+
+# for label, im_data in visuals.items():
+# im = util.tensor2im(im_data)
+# image_name = '%s_%s.png' % (name, label)
+# save_path = os.path.join(image_dir, image_name)
+# h, w, _ = im.shape
+# if aspect_ratio > 1.0:
+# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
+# if aspect_ratio < 1.0:
+# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
+# util.save_image(im, save_path)
+
+# ims.append(image_name)
+# txts.append(label)
+# links.append(image_name)
+# webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+ def __init__(self, opt):
+ self.display_id = opt.display_id
+ self.use_html = opt.isTrain and not opt.no_html
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.opt = opt
+ self.saved = False
+ if self.display_id > 0:
+ import visdom
+ self.ncols = opt.display_ncols
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True)
+
+ if self.use_html:
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ self.saved = False
+
+ def throw_visdom_connection_error(self):
+ print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
+ exit(1)
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch, save_result):
+ if self.display_id > 0: # show images in the browser
+ ncols = self.ncols
+ if ncols > 0:
+ ncols = min(ncols, len(visuals))
+ h, w = next(iter(visuals.values())).shape[:2]
+ table_css = """""" % (w, h)
+ title = self.name
+ label_html = ''
+ label_html_row = ''
+ images = []
+ idx = 0
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ label_html_row += '<td>%s</td>' % label
+ images.append(image_numpy.transpose([2, 0, 1]))
+ idx += 1
+ if idx % ncols == 0:
+ label_html += '<tr>%s</tr>' % label_html_row
+ label_html_row = ''
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+ while idx % ncols != 0:
+ images.append(white_image)
+ label_html_row += '<td></td>'
+ idx += 1
+ if label_html_row != '':
+ label_html += '<tr>%s</tr>' % label_html_row
+ # pane col = image row
+ try:
+ self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+ padding=2, opts=dict(title=title + ' images'))
+ label_html = '<table>%s</table>' % label_html
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
+ opts=dict(title=title + ' labels'))
+ except ConnectionError:
+ self.throw_visdom_connection_error()
+
+ else:
+ idx = 1
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+ win=self.display_id + idx)
+ idx += 1
+
+ if self.use_html and (save_result or not self.saved): # save images to a html file
+ self.saved = True
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ # losses: dictionary of error labels and values
+ def plot_current_losses(self, epoch, counter_ratio, opt, losses):
+ if not hasattr(self, 'plot_data'):
+ self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+ self.plot_data['X'].append(epoch + counter_ratio)
+ self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+ try:
+ self.vis.line(
+ X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+ Y=np.array(self.plot_data['Y']),
+ opts={
+ 'title': self.name + ' loss over time',
+ 'legend': self.plot_data['legend'],
+ 'xlabel': 'epoch',
+ 'ylabel': 'loss'},
+ win=self.display_id)
+ except ConnectionError:
+ self.throw_visdom_connection_error()
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, i, losses, t, t_data):
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message)