diff --git a/FaceLandmarkDetection/Dockerfile b/FaceLandmarkDetection/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..39c5d4a5841ee70227428c47c8c02de4a588e150 --- /dev/null +++ b/FaceLandmarkDetection/Dockerfile @@ -0,0 +1,33 @@ +# Based on https://github.com/pytorch/pytorch/blob/master/Dockerfile +FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + git \ + curl \ + vim \ + ca-certificates \ + libboost-all-dev \ + python-qt4 \ + libjpeg-dev \ + libpng-dev &&\ + rm -rf /var/lib/apt/lists/* + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh + +ENV PATH /opt/conda/bin:$PATH + +RUN conda config --set always_yes yes --set changeps1 no && conda update -q conda +RUN conda install pytorch torchvision cuda92 -c pytorch + +# Install face-alignment package +WORKDIR /workspace +RUN chmod -R a+w /workspace +RUN git clone https://github.com/1adrianb/face-alignment +WORKDIR /workspace/face-alignment +RUN pip install -r requirements.txt +RUN python setup.py install diff --git a/FaceLandmarkDetection/LICENSE b/FaceLandmarkDetection/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ed4c6d1af07162de4ff7d0192fdbbf42acd77beb --- /dev/null +++ b/FaceLandmarkDetection/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2017, Adrian Bulat +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/FaceLandmarkDetection/README.md b/FaceLandmarkDetection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..211370c6bc8954d4285ae5575b1a6401e19b9874 --- /dev/null +++ b/FaceLandmarkDetection/README.md @@ -0,0 +1,183 @@ +# Face Recognition + +Detect facial landmarks from Python using the world's most accurate face alignment network, capable of detecting points in both 2D and 3D coordinates. 
+ +Built using [FAN](https://www.adrianbulat.com)'s state-of-the-art deep learning based face alignment method. + +
+ +**Note:** The lua version is available [here](https://github.com/1adrianb/2D-and-3D-face-alignment). + +For numerical evaluations it is highly recommended to use the lua version, which uses identical models to the ones evaluated in the paper. More models will be added soon. + +[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Build Status](https://travis-ci.com/1adrianb/face-alignment.svg?branch=master)](https://travis-ci.com/1adrianb/face-alignment) [![Anaconda-Server Badge](https://anaconda.org/1adrianb/face_alignment/badges/version.svg)](https://anaconda.org/1adrianb/face_alignment) +[![PyPI](https://img.shields.io/pypi/v/nine.svg?style=flat-square)](https://pypi.org/project/face-alignment/) + +## Features + +#### Detect 2D facial landmarks in pictures + +
+ +```python +import face_alignment +from skimage import io + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + +input = io.imread('../test/assets/aflw-test.jpg') +preds = fa.get_landmarks(input) +``` + +#### Detect 3D facial landmarks in pictures + +
+ +
+ +```python +import face_alignment +from skimage import io + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False) + +input = io.imread('../test/assets/aflw-test.jpg') +preds = fa.get_landmarks(input) +``` + +#### Process an entire directory in one go + +```python +import face_alignment +from skimage import io + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + +preds = fa.get_landmarks_from_directory('../test/assets/') +``` + +#### Detect the landmarks using a specific face detector + +By default the package will use the SFD face detector. However, users can alternatively use dlib or pre-existing ground-truth bounding boxes. + +```python +import face_alignment + +# sfd for SFD, dlib for Dlib and folder for existing bounding boxes. +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector='sfd') +``` + +#### Running on CPU/GPU +To specify the device (GPU or CPU) on which the code will run, explicitly pass the device flag: + +```python +import face_alignment + +# cuda for CUDA +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu') +``` + +Please also see the ``examples`` folder. + +## Installation + +### Requirements + +* Python 3.5+ or Python 2.7 (it may work with other versions too) +* Linux, Windows or macOS +* pytorch (>=1.0) + +While not required, for optimal performance (especially for the detector) it is **highly** recommended to run the code using a CUDA-enabled GPU. + +### Binaries + +The easiest way to install it is using either pip or conda: + +| **Using pip** | **Using conda** | +|------------------------------|--------------------------------------------| +| `pip install face-alignment` | `conda install -c 1adrianb face_alignment` | + +Alternatively, you can find instructions below to build it from source. + +### From source + + Install pytorch and its dependencies. Instructions taken from the [pytorch readme](https://github.com/pytorch/pytorch). For a more up-to-date version, check the framework's GitHub page. + + On Linux +```bash +export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory] + +# Install basic dependencies +conda install numpy pyyaml mkl setuptools cmake gcc cffi + +# Add LAPACK support for the GPU +conda install -c soumith magma-cuda80 # or magma-cuda75 if CUDA 7.5 +``` + +On OSX +```bash +export CMAKE_PREFIX_PATH=[anaconda root directory] +conda install numpy pyyaml setuptools cmake cffi +``` +#### Get the PyTorch source +```bash +git clone --recursive https://github.com/pytorch/pytorch +``` + +#### Install PyTorch +On Linux +```bash +python setup.py install +``` + +On OSX +```bash +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install +``` + +#### Get the Face Alignment source code +```bash +git clone https://github.com/1adrianb/face-alignment +``` +#### Install the Face Alignment lib +```bash +pip install -r requirements.txt +python setup.py install +``` + +### Docker image + +A Dockerfile is provided to build images with CUDA and cuDNN support. For more instructions about building and running a Docker image, check the official Docker documentation. +``` +docker build -t face-alignment . +``` + +## How does it work? + +While the work is presented here as a black box, if you want to know more about the inner workings of the method please check the original paper, either on arXiv or on my [webpage](https://www.adrianbulat.com).
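+
+#### Visualising the predictions
+
+Each entry returned by ``get_landmarks``/``get_landmarks_from_image`` is a numpy array of shape ``(68, 2)`` (or ``(68, 3)`` for the 3D model), one per detected face. The snippet below is a minimal sketch of how these points can be overlaid on the input image; it is not part of the package's own examples and assumes ``matplotlib`` is installed:
+
+```python
+import face_alignment
+import matplotlib.pyplot as plt
+from skimage import io
+
+fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+
+image = io.imread('../test/assets/aflw-test.jpg')
+# Returns a list with one (68, 2) array per detected face, or None if no face was found
+preds = fa.get_landmarks(image)
+
+if preds is not None:
+    plt.imshow(image)
+    for landmarks in preds:
+        plt.scatter(landmarks[:, 0], landmarks[:, 1], s=8)
+    plt.show()
+```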
+ +## Contributions + +All contributions are welcomed. If you encounter any issue (including examples of images where it fails) feel free to open an issue. + +## Citation + +``` +@inproceedings{bulat2017far, + title={How far are we from solving the 2D \& 3D Face Alignment problem? (and a dataset of 230,000 3D facial landmarks)}, + author={Bulat, Adrian and Tzimiropoulos, Georgios}, + booktitle={International Conference on Computer Vision}, + year={2017} +} +``` + +For citing dlib, pytorch or any other packages used here please check the original page of their respective authors. + +## Acknowledgements + +* To the [pytorch](http://pytorch.org/) team for providing such an awesome deeplearning framework +* To [my supervisor](http://www.cs.nott.ac.uk/~pszyt/) for his patience and suggestions. +* To all other python developers that made available the rest of the packages used in this repository. diff --git a/FaceLandmarkDetection/docs/images/2dlandmarks.png b/FaceLandmarkDetection/docs/images/2dlandmarks.png new file mode 100644 index 0000000000000000000000000000000000000000..f815722b4b22066e21fda9210400e86d36019f8a Binary files /dev/null and b/FaceLandmarkDetection/docs/images/2dlandmarks.png differ diff --git a/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif b/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif new file mode 100644 index 0000000000000000000000000000000000000000..b6bb4224b922dfa04b7fa61805af65e56a9a08be Binary files /dev/null and b/FaceLandmarkDetection/docs/images/face-alignment-adrian.gif differ diff --git a/FaceLandmarkDetection/face_alignment/__init__.py b/FaceLandmarkDetection/face_alignment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bae29fd5f85b41e4669302bd2603bc6924eddc7 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__author__ = """Adrian Bulat""" +__email__ = 'adrian.bulat@nottingham.ac.uk' +__version__ = '1.0.1' + +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26bc9e86be76e79e7a32a0c66f262f2581e87fb5 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..079f72e8c64729757200ce1dcbb9f62e36db71de Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/__init__.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8b0efd010573c3ec4fd536afe90d65547b0cbdc Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf659733951be5a40c9a60d2fe54a8eb18bfb97f Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/api.cpython-37.pyc differ diff --git 
a/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b480edcd636d96fe427e73d5d1006bd1e91b36 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3407599d10993ad2af893304eb58cc4e94456be Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/models.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b3f0648a7b85eea4c2d8eedcd5434a40238744b Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85e82aaa020db2e651ee1f7ea67e71195a6f1c3 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/__pycache__/utils.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/api.py b/FaceLandmarkDetection/face_alignment/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a09a3095d08ef20c46550662766000d13d507c --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/api.py @@ -0,0 +1,207 @@ +from __future__ import print_function +import os +import torch +from torch.utils.model_zoo import load_url +from enum import Enum +from skimage import io +from skimage import color +import numpy as np +import cv2 +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +from .models import FAN, ResNetDepth +from .utils import * + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. + + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - this points represent the projection of the 3D points into 3D + ``_3D`` - detect the points ``(x,y,z)``` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + +models_urls = { + '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar', + '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar', + 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar', +} + + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + # Get the face detector + face_detector_module = __import__('face_alignment.detection.' 
+ face_detector, + globals(), locals(), [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + # Initialise the face alignemnt networks + self.face_alignment_net = FAN(network_size) + if landmarks_type == LandmarksType._2D: + network_name = '2DFAN-' + str(network_size) + else: + network_name = '3DFAN-' + str(network_size) + + fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage) + self.face_alignment_net.load_state_dict(fan_weights) + + self.face_alignment_net.to(device) + self.face_alignment_net.eval() + + # Initialiase the depth prediciton network + if landmarks_type == LandmarksType._3D: + self.depth_prediciton_net = ResNetDepth() + + depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage) + depth_dict = { + k.replace('module.', ''): v for k, + v in depth_weights['state_dict'].items()} + self.depth_prediciton_net.load_state_dict(depth_dict) + + self.depth_prediciton_net.to(device) + self.depth_prediciton_net.eval() + + def get_landmarks(self, image_or_path, detected_faces=None): + """Deprecated, please use get_landmarks_from_image + + Arguments: + image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it. + + Keyword Arguments: + detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found + in the image (default: {None}) + """ + return self.get_landmarks_from_image(image_or_path, detected_faces) + + def get_landmarks_from_image(self, image_or_path, detected_faces=None): + """Predict the landmarks for each face present in the image. + + This function predicts a set of 68 2D or 3D images, one for each image present. + If detect_faces is None the method will also run a face detector. + + Arguments: + image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it. 
+ + Keyword Arguments: + detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found + in the image (default: {None}) + """ + if isinstance(image_or_path, str): + try: + image = io.imread(image_or_path) + except IOError: + print("error opening file :: ", image_or_path) + return None + else: + image = image_or_path + + if image.ndim == 2: + image = color.gray2rgb(image) + elif image.ndim == 4: + image = image[..., :3] + + if detected_faces is None: + detected_faces = self.face_detector.detect_from_image(image[..., ::-1].copy()) + + if len(detected_faces) == 0: + print("Warning: No faces were detected.") + return None + + torch.set_grad_enabled(False) + landmarks = [] + for i, d in enumerate(detected_faces): + center = torch.FloatTensor( + [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0]) + center[1] = center[1] - (d[3] - d[1]) * 0.12 + scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale + + inp = crop(image, center, scale) + inp = torch.from_numpy(inp.transpose( + (2, 0, 1))).float() + + inp = inp.to(self.device) + inp.div_(255.0).unsqueeze_(0) + + out = self.face_alignment_net(inp)[-1].detach() + if self.flip_input: + out += flip(self.face_alignment_net(flip(inp)) + [-1].detach(), is_label=True) + out = out.cpu() + + pts, pts_img = get_preds_fromhm(out, center, scale) + pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2) + + if self.landmarks_type == LandmarksType._3D: + heatmaps = np.zeros((68, 256, 256), dtype=np.float32) + for i in range(68): + if pts[i, 0] > 0: + heatmaps[i] = draw_gaussian( + heatmaps[i], pts[i], 2) + heatmaps = torch.from_numpy( + heatmaps).unsqueeze_(0) + + heatmaps = heatmaps.to(self.device) + depth_pred = self.depth_prediciton_net( + torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1) + pts_img = torch.cat( + (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1) + + landmarks.append(pts_img.numpy()) + + return landmarks + + def get_landmarks_from_directory(self, path, extensions=['.jpg', '.png'], recursive=True, show_progress_bar=True): + detected_faces = self.face_detector.detect_from_directory(path, extensions, recursive, show_progress_bar) + + predictions = {} + for image_path, bounding_boxes in detected_faces.items(): + image = io.imread(image_path) + preds = self.get_landmarks_from_image(image, bounding_boxes) + predictions[image_path] = preds + + return predictions + + @staticmethod + def remove_models(self): + base_path = os.path.join(appdata_dir('face_alignment'), "data") + for data_model in os.listdir(base_path): + file_path = os.path.join(base_path, data_model) + try: + if os.path.isfile(file_path): + print('Removing ' + data_model + ' ...') + os.unlink(file_path) + except Exception as e: + print(e) diff --git a/FaceLandmarkDetection/face_alignment/detection/__init__.py b/FaceLandmarkDetection/face_alignment/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6b0402dae864a3cc5dc2a90a412fd842a0efc7 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector \ No newline at end of file diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09ea4818d422fd3267003f907aad6a7b79fd829e Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc 
differ diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..605c422babf01756a23b39a4b869b24896a17586 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/__init__.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d2a9d3998cf9c2facbbf5b6731d067c7193967 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95657db16b78133d1e8f86160cb96fe371b5ba16 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/__pycache__/core.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/core.py b/FaceLandmarkDetection/face_alignment/detection/core.py new file mode 100644 index 0000000000000000000000000000000000000000..47a3f9d5fcf461843b2aafb75031f06be591c7dd --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/core.py @@ -0,0 +1,131 @@ +import logging +import glob +from tqdm import tqdm +import numpy as np +import torch +import cv2 +from skimage import io + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. + """ + + def __init__(self, device, verbose): + self.device = device + self.verbose = verbose + + if verbose: + if 'cpu' in device: + logger = logging.getLogger(__name__) + logger.warning("Detection running on CPU, this may be potentially slow.") + + if 'cpu' not in device and 'cuda' not in device: + if verbose: + logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) + raise ValueError + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): + """Detects faces from all the images present in a given directory. 
+ + Arguments: + path {string} -- a string containing a path that points to the folder containing the images + + Keyword Arguments: + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error("Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. %s images found", len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or torch.tensor to a numpy.ndarray + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path) + elif torch.is_tensor(tensor_or_path): + # Call cpu in case its coming from cuda + return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path + else: + raise TypeError diff --git a/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py b/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e5ee5fd5b6dd145f3e3ec65c02d2d8befaef59 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/dlib/__init__.py @@ -0,0 +1 @@ +from .dlib_detector import DlibDetector as FaceDetector \ No newline at end of file diff --git a/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py b/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc8368aef0203753fad4637b4e3981281977fe6 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/dlib/dlib_detector.py @@ -0,0 +1,68 @@ +import os +import cv2 +import dlib + +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +from ..core import FaceDetector +from ...utils import appdata_dir + + +class 
DlibDetector(FaceDetector): + def __init__(self, device, path_to_detector=None, verbose=False): + super().__init__(device, verbose) + + print('Warning: this detector is deprecated. Please use a different one, i.e.: S3FD.') + base_path = os.path.join(appdata_dir('face_alignment'), "data") + + # Initialise the face detector + if 'cuda' in device: + if path_to_detector is None: + path_to_detector = os.path.join( + base_path, "mmod_human_face_detector.dat") + + if not os.path.isfile(path_to_detector): + print("Downloading the face detection CNN. Please wait...") + + path_to_temp_detector = os.path.join( + base_path, "mmod_human_face_detector.dat.download") + + if os.path.isfile(path_to_temp_detector): + os.remove(os.path.join(path_to_temp_detector)) + + request_file.urlretrieve( + "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat", + os.path.join(path_to_temp_detector)) + + os.rename(os.path.join(path_to_temp_detector), os.path.join(path_to_detector)) + + self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector) + else: + self.face_detector = dlib.get_frontal_face_detector() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path, rgb=False) + + detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)) + + if 'cuda' not in self.device: + detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces] + else: + detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces] + + return detected_faces + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py b/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9128ee0e13eb6f0058dbd480046aee50be336f --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/folder/__init__.py @@ -0,0 +1 @@ +from .folder_detector import FolderDetector as FaceDetector \ No newline at end of file diff --git a/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py b/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..add19fa4e4b933619b75d1dce441e3071b235191 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/folder/folder_detector.py @@ -0,0 +1,53 @@ +import os +import numpy as np +import torch + +from ..core import FaceDetector + + +class FolderDetector(FaceDetector): + '''This is a simple helper module that assumes the faces were detected already + (either previously or are provided as ground truth). + + The class expects to find the bounding boxes in the same format used by + the rest of face detectors, mainly ``list[(x1,y1,x2,y2),...]``. 
+ For each image the detector will search for a file with the same name and with one of the + following extensions: .npy, .t7 or .pth + + ''' + + def __init__(self, device, path_to_detector=None, verbose=False): + super(FolderDetector, self).__init__(device, verbose) + + def detect_from_image(self, tensor_or_path): + # Only strings supported + if not isinstance(tensor_or_path, str): + raise ValueError + + base_name = os.path.splitext(tensor_or_path)[0] + + if os.path.isfile(base_name + '.npy'): + detected_faces = np.load(base_name + '.npy') + elif os.path.isfile(base_name + '.t7'): + detected_faces = torch.load(base_name + '.t7') + elif os.path.isfile(base_name + '.pth'): + detected_faces = torch.load(base_name + '.pth') + else: + raise FileNotFoundError + + if not isinstance(detected_faces, list): + raise TypeError + + return detected_faces + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py b/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a63ecd45658f22e66c171ada751fb33764d4559 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector \ No newline at end of file diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223f1c8c50f26c44013f12edbb297b3f0822d817 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71478176e62df158a6458d825eedf176cff1cb16 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/__init__.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a77197d502544882d7676b71e89525fab5eaab Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..315d9b9b15d5b094b864f7a4a40c863c5276af0e Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/bbox.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37daedbfb169a7bf35c91a6d9abbf44d22c4c315 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc differ diff --git 
a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5954c1300f601d431745d74ef3cfded16b6980d Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/detect.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0984c1e8a9c2bb8b78fc70c10f6fbc1f6cafc1aa Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bff6f907dbf2c33fdaa0062748bb28a76e8b65ad Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b5f8b065f4da5d65f432a774e5b08620da4723 Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33f1ccaeb437d82a166b0fb1603628525796c02f Binary files /dev/null and b/FaceLandmarkDetection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-37.pyc differ diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py b/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..ca662365044608b992e7f15794bd2518b0781d7e --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/sfd/bbox.py @@ -0,0 +1,109 @@ +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import torch + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 
0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py b/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..84d120de15145a830938a2edcf9cf8d29bc59f72 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/sfd/detect.py @@ -0,0 +1,75 @@ +import torch +import torch.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img, device): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1,) + img.shape) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + img = torch.from_numpy(img).float().to(device) + BB, CC, HH, WW = img.size() + with torch.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + + +def flip_detect(net, img, device): + img = cv2.flip(img, 1) + b = detect(net, img, device) + + bboxlist = np.zeros(b.shape) + bboxlist[:, 0] = img.shape[1] - b[:, 2] + bboxlist[:, 1] = b[:, 1] + bboxlist[:, 2] = img.shape[1] - b[:, 0] + bboxlist[:, 3] = b[:, 3] + bboxlist[:, 4] = b[:, 4] + return bboxlist + + +def pts_to_bb(pts): + min_x, min_y = np.min(pts, axis=0) + max_x, max_y = np.max(pts, axis=0) + return np.array([min_x, min_y, max_x, max_y]) diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py b/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py new file mode 100644 index 0000000000000000000000000000000000000000..fc64313c277ab594d0257585c70f147606693452 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/sfd/net_s3fd.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L2Norm(nn.Module): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * 
self.weight.view(1, -1, 1, 1) + return x + + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = 
self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1, 4, 1) + bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) + cls1 = torch.cat([bmax, chunk[3]], dim=1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] diff --git a/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py b/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..29c7768558b83dceea8f94e4d859b71dfa54cc85 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/detection/sfd/sfd_detector.py @@ -0,0 +1,51 @@ +import os +import cv2 +from torch.utils.model_zoo import load_url + +from ..core import FaceDetector + +from .net_s3fd import s3fd +from .bbox import * +from .detect import * + +models_urls = { + 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', +} + + +class SFDDetector(FaceDetector): + def __init__(self, device, path_to_detector=None, verbose=False): + super(SFDDetector, self).__init__(device, verbose) + + # Initialise the face detector + if path_to_detector is None: + model_weights = load_url(models_urls['s3fd']) + else: + model_weights = torch.load(path_to_detector) + + self.face_detector = s3fd() + self.face_detector.load_state_dict(model_weights) + self.face_detector.to(device) + self.face_detector.eval() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path) + + bboxlist = detect(self.face_detector, image, device=self.device) + keep = nms(bboxlist, 0.3) + bboxlist = bboxlist[keep, :] + bboxlist = [x for x in bboxlist if x[-1] > 0.5] + + return bboxlist + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/FaceLandmarkDetection/face_alignment/models.py b/FaceLandmarkDetection/face_alignment/models.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2dde32bdf72c25a4600e48efa73ffc0d4a3893 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/models.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=strd, padding=padding, bias=bias) + + +class ConvBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = 
None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HourGlass(nn.Module): + def __init__(self, num_modules, depth, num_features): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x): + return self._forward(self.depth, x) + + +class FAN(nn.Module): + + def __init__(self, num_modules=1): + super(FAN, self).__init__() + self.num_modules = num_modules + + # Base part + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, + 68, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), nn.Conv2d(256, 256, 
kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), nn.Conv2d(68, + 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)] + (self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs + + +class ResNetDepth(nn.Module): + + def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): + self.inplanes = 64 + super(ResNetDepth, self).__init__() + self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/FaceLandmarkDetection/face_alignment/utils.py b/FaceLandmarkDetection/face_alignment/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..619570d580d65c84ba60ee043ad381a850f1ab26 --- /dev/null +++ b/FaceLandmarkDetection/face_alignment/utils.py @@ -0,0 +1,274 @@ +from __future__ import print_function +import os +import sys +import time +import torch +import math +import numpy as np +import cv2 + + +def _gaussian( + size=3, sigma=0.25, amplitude=1, normalize=False, width=None, + height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, + mean_vert=0.5): + # handle some defaults + if width is None: + width = size + if height is None: + height = size + if sigma_horz is None: + sigma_horz = sigma + if sigma_vert is None: + sigma_vert = sigma + center_x = mean_horz * width + 0.5 + center_y = mean_vert * height + 0.5 + gauss = np.empty((height, width), 
dtype=np.float32) + # generate kernel + for i in range(height): + for j in range(width): + gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( + sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) + if normalize: + gauss = gauss / np.sum(gauss) + return gauss + + +def draw_gaussian(image, point, sigma): + # Check if the gaussian is inside + ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] + br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] + if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): + return image + size = 6 * sigma + 1 + g = _gaussian(size) + g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] + g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] + img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] + img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] + assert (g_x[0] > 0 and g_y[1] > 0) + image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] + ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] + image[image > 1] = 1 + return image + + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. + + Arguments: + point {torch.tensor} -- the input 2D point + center {torch.tensor or numpy.array} -- the center around which to perform the transformations + scale {float} -- the scale of the face/object + resolution {float} -- the output resolution + + Keyword Arguments: + invert {bool} -- define wherever the function should produce the direct or the + inverse transformation matrix (default: {False}) + """ + _pt = torch.ones(3) + _pt[0] = point[0] + _pt[1] = point[1] + + h = 200.0 * scale + t = torch.eye(3) + t[0, 0] = resolution / h + t[1, 1] = resolution / h + t[0, 2] = resolution * (-center[0] / h + 0.5) + t[1, 2] = resolution * (-center[1] / h + 0.5) + + if invert: + t = torch.inverse(t) + + new_point = (torch.matmul(t, _pt))[0:2] + + return new_point.int() + + +def crop(image, center, scale, resolution=256.0): + """Center crops an image or set of heatmaps + + Arguments: + image {numpy.array} -- an rgb image + center {numpy.array} -- the center of the object, usually the same as of the bounding box + scale {float} -- scale of the face + + Keyword Arguments: + resolution {float} -- the size of the output cropped image (default: {256.0}) + + Returns: + [type] -- [description] + """ # Crop around the center point + """ Crops the image around the center. 
Input is expected to be an np.ndarray """ + ul = transform([1, 1], center, scale, resolution, True) + br = transform([resolution, resolution], center, scale, resolution, True) + # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], + image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array( + [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array( + [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] + ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), + interpolation=cv2.INTER_LINEAR) + return newImg + + +def get_preds_fromhm(hm, center=None, scale=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the center + and the scale is provided the function will return the points also in + the original coordinate frame. + + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + center {torch.tensor} -- the center of the bounding box (default: {None}) + scale {float} -- face scale (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if center is not None and scale is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], center, scale, hm.size(2), True) + + return preds, preds_orig + + +def shuffle_lr(parts, pairs=None): + """Shuffle the points left-right according to the axis of symmetry + of the object. + + Arguments: + parts {torch.tensor} -- a 3D or 4D object containing the + heatmaps. + + Keyword Arguments: + pairs {list of integers} -- [order of the flipped points] (default: {None}) + """ + if pairs is None: + pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, + 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, + 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, + 62, 61, 60, 67, 66, 65] + if parts.ndimension() == 3: + parts = parts[pairs, ...] + else: + parts = parts[:, pairs, ...] 
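+    # The channels have now been re-indexed with `pairs`, which maps each of the 68 landmarks to its horizontally mirrored counterpart.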
+ + return parts + + +def flip(tensor, is_label=False): + """Flip an image or a set of heatmaps left-right + + Arguments: + tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] + + Keyword Arguments: + is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) + """ + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + if is_label: + tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) + else: + tensor = tensor.flip(tensor.ndimension() - 1) + + return tensor + +# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) + + +def appdata_dir(appname=None, roaming=False): + """ appdata_dir(appname=None, roaming=False) + + Get the path to the application directory, where applications are allowed + to write user specific files (e.g. configurations). For non-user specific + data, consider using common_appdata_dir(). + If appname is given, a subdir is appended (and created if necessary). + If roaming is True, will prefer a roaming directory (Windows Vista/7). + """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' 
+ appname.lstrip('.') # Make it a hidden directory + path = os.path.join(path, appname) + if not os.path.isdir(path): # pragma: no cover + os.mkdir(path) + + # Done + return path diff --git a/FaceLandmarkDetection/get_face_landmark.py b/FaceLandmarkDetection/get_face_landmark.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee59679395824add92f2539a5a2bec3443a0ec9 --- /dev/null +++ b/FaceLandmarkDetection/get_face_landmark.py @@ -0,0 +1,46 @@ +#!/usr/bin/python #encoding:utf-8 +import torch +import pickle +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image + +import cv2 +import os +import face_alignment +from skimage import io, transform + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,device='cuda:0', flip_input=False) + + +Nums = 0 +FilePath = '../TestData/RealVgg/Imgs' +SavePath = '../TestData/RealVgg/Imgs/Landmarks' +if not os.path.exists(SavePath): + os.mkdir(SavePath) + +ImgNames = os.listdir(FilePath) +ImgNames.sort() + +for i,name in enumerate(ImgNames): + print((i,name)) + + imgs = io.imread(os.path.join(FilePath,name)) + + imgO = imgs + try: + PredsAll = fa.get_landmarks(imgO) + except: + print('#########No face') + continue + if PredsAll is None: + print('#########No face2') + continue + if len(PredsAll)!=1: + print('#########too many face') + continue + preds = PredsAll[-1] + AddLength = np.sqrt(np.sum(np.power(preds[27][0:2]-preds[33][0:2],2))) + SaveName = name+'.txt' + + np.savetxt(os.path.join(SavePath,SaveName),preds[:,0:2],fmt='%.3f') diff --git a/FaceLandmarkDetection/setup.cfg b/FaceLandmarkDetection/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..1d5d127316734bcda878fffbbc24c56404911739 --- /dev/null +++ b/FaceLandmarkDetection/setup.cfg @@ -0,0 +1,32 @@ +[bumpversion] +current_version = 1.0.1 +commit = True +tag = True + +[bumpversion:file:setup.py] +search = version='{current_version}' +replace = version='{new_version}' + +[bumpversion:file:face_alignment/__init__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + +[metadata] +description-file = README.md + +[bdist_wheel] +universal = 1 + +[flake8] +exclude = + .github, + examples, + docs, + .tox, + bin, + dist, + tools, + *.egg-info, + __init__.py, + *.yml +max-line-length = 160 \ No newline at end of file diff --git a/FaceLandmarkDetection/setup.py b/FaceLandmarkDetection/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1a0be9bbdbc3d8e8a07e07ab11b88b5b865a30 --- /dev/null +++ b/FaceLandmarkDetection/setup.py @@ -0,0 +1,83 @@ +import io +import os +from os import path +import re +from setuptools import setup, find_packages +# To use consisten encodings +from codecs import open + +# Function from: https://github.com/pytorch/vision/blob/master/setup.py + + +def read(*names, **kwargs): + with io.open( + os.path.join(os.path.dirname(__file__), *names), + encoding=kwargs.get("encoding", "utf8") + ) as fp: + return fp.read() + +# Function from: https://github.com/pytorch/vision/blob/master/setup.py + + +def find_version(*file_paths): + version_file = read(*file_paths) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", + version_file, re.M) + if version_match: + return version_match.group(1) + raise RuntimeError("Unable to find version string.") + +here = path.abspath(path.dirname(__file__)) + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file: + long_description = 
readme_file.read() + +VERSION = find_version('face_alignment', '__init__.py') + +requirements = [ + 'torch', + 'numpy', + 'scipy>=0.17', + 'scikit-image', + 'opencv-python', + 'tqdm', + 'enum34;python_version<"3.4"' +] + +setup( + name='face_alignment', + version=VERSION, + + description="Detector 2D or 3D face landmarks from Python", + long_description=long_description, + long_description_content_type="text/markdown", + + # Author details + author="Adrian Bulat", + author_email="adrian.bulat@nottingham.ac.uk", + url="https://github.com/1adrianb/face-alignment", + + # Package info + packages=find_packages(exclude=('test',)), + + install_requires=requirements, + license='BSD', + zip_safe=True, + + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Operating System :: OS Independent', + 'License :: OSI Approved :: BSD License', + 'Natural Language :: English', + + # Supported python versions + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + ], +) diff --git a/FaceLandmarkDetection/test/assets/aflw-test.jpg b/FaceLandmarkDetection/test/assets/aflw-test.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e798f64cdeda860e3b2a99e0d7e5ce8e52d7c0f4 Binary files /dev/null and b/FaceLandmarkDetection/test/assets/aflw-test.jpg differ diff --git a/FaceLandmarkDetection/test/facealignment_test.py b/FaceLandmarkDetection/test/facealignment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..85637f953ab580fd8756f3476afa02723a8e12c0 --- /dev/null +++ b/FaceLandmarkDetection/test/facealignment_test.py @@ -0,0 +1,11 @@ +import unittest +import face_alignment + + +class Tester(unittest.TestCase): + def test_predict_points(self): + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu') + fa.get_landmarks('test/assets/aflw-test.jpg') + +if __name__ == '__main__': + unittest.main() diff --git a/FaceLandmarkDetection/test/smoke_test.py b/FaceLandmarkDetection/test/smoke_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93a006d0793ac3aa1cdaedf57f45eac7b9672464 --- /dev/null +++ b/FaceLandmarkDetection/test/smoke_test.py @@ -0,0 +1,2 @@ +import torch +import face_alignment diff --git a/FaceLandmarkDetection/test/test_utils.py b/FaceLandmarkDetection/test/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09c4d3eb0d7496d0900f3e32f14bfa8172a4252c --- /dev/null +++ b/FaceLandmarkDetection/test/test_utils.py @@ -0,0 +1,36 @@ +import unittest +from face_alignment.utils import * +import numpy as np +import torch + + +class Tester(unittest.TestCase): + def test_flip_is_label(self): + # Generate the points + heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32')) + + flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True) + + assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy()) + + def test_flip_is_image(self): + fake_image = torch.torch.rand(3, 256, 256) + fliped_fake_image = flip(flip(fake_image.clone())) + + assert np.allclose(fake_image.numpy(), fliped_fake_image.numpy()) + + def test_getpreds(self): + pts = torch.from_numpy(np.random.randint(1, high=63, size=(68, 2)).astype('float32')) + + heatmaps = np.zeros((68, 256, 256)) + for i in range(68): + if 
pts[i, 0] > 0: + heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2) + heatmaps = torch.from_numpy(np.expand_dims(heatmaps, axis=0)) + + preds, _ = get_preds_fromhm(heatmaps) + + assert np.allclose(pts.numpy(), preds.numpy(), atol=5) + +if __name__ == '__main__': + unittest.main() diff --git a/FaceLandmarkDetection/tox.ini b/FaceLandmarkDetection/tox.ini new file mode 100644 index 0000000000000000000000000000000000000000..0a862cd793d75c6005fb4dcda05edc58044f86b3 --- /dev/null +++ b/FaceLandmarkDetection/tox.ini @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +ignore = E305,E402,E721,F401,F403,F405,F821,F841,F999 \ No newline at end of file diff --git a/Imgs/RealLR/n000056_0060_01.png b/Imgs/RealLR/n000056_0060_01.png new file mode 100644 index 0000000000000000000000000000000000000000..005309a61a613bee4270ae4d05fb7320b1bde0be Binary files /dev/null and b/Imgs/RealLR/n000056_0060_01.png differ diff --git a/Imgs/RealLR/n000067_0228_01.png b/Imgs/RealLR/n000067_0228_01.png new file mode 100644 index 0000000000000000000000000000000000000000..c0e4f322ec3f5db1554777e233237c5cdd0fb007 Binary files /dev/null and b/Imgs/RealLR/n000067_0228_01.png differ diff --git a/Imgs/RealLR/n000184_0094_01.png b/Imgs/RealLR/n000184_0094_01.png new file mode 100644 index 0000000000000000000000000000000000000000..3e2995de5706ed71a445d3519d2c214df503b9ae Binary files /dev/null and b/Imgs/RealLR/n000184_0094_01.png differ diff --git a/Imgs/RealLR/n000241_0132_04.png b/Imgs/RealLR/n000241_0132_04.png new file mode 100644 index 0000000000000000000000000000000000000000..4f058089b0320d201f35736148e7561ad5a625d1 Binary files /dev/null and b/Imgs/RealLR/n000241_0132_04.png differ diff --git a/Imgs/RealLR/n000262_0097_01.png b/Imgs/RealLR/n000262_0097_01.png new file mode 100644 index 0000000000000000000000000000000000000000..f9aedb26d7d4c885385bf6136cf170a2c8a4f4d9 Binary files /dev/null and b/Imgs/RealLR/n000262_0097_01.png differ diff --git a/Imgs/ShowResults/n000056_0060_01.png b/Imgs/ShowResults/n000056_0060_01.png new file mode 100644 index 0000000000000000000000000000000000000000..cda40a58b3291d7c88a4930c05085604fb31f3fd Binary files /dev/null and b/Imgs/ShowResults/n000056_0060_01.png differ diff --git a/Imgs/ShowResults/n000067_0228_01.png b/Imgs/ShowResults/n000067_0228_01.png new file mode 100644 index 0000000000000000000000000000000000000000..9a4cf68d2a0d18fe8e195769d333d46d4a778f40 Binary files /dev/null and b/Imgs/ShowResults/n000067_0228_01.png differ diff --git a/Imgs/ShowResults/n000184_0094_01.png b/Imgs/ShowResults/n000184_0094_01.png new file mode 100644 index 0000000000000000000000000000000000000000..56fc0bf25806c9a9ec0c47d557636f514ace5b97 Binary files /dev/null and b/Imgs/ShowResults/n000184_0094_01.png differ diff --git a/Imgs/ShowResults/n000241_0132_04.png b/Imgs/ShowResults/n000241_0132_04.png new file mode 100644 index 0000000000000000000000000000000000000000..e35b7c0e664ffd443f000fb05bcbb1fd442c0e2e Binary files /dev/null and b/Imgs/ShowResults/n000241_0132_04.png differ diff --git a/Imgs/ShowResults/n000262_0097_01.png b/Imgs/ShowResults/n000262_0097_01.png new file mode 100644 index 0000000000000000000000000000000000000000..62590ddd60020fb4eb2dee611c87b810ebdfc6a5 Binary files /dev/null and b/Imgs/ShowResults/n000262_0097_01.png differ diff --git a/Imgs/Whole/test1_0.jpg b/Imgs/Whole/test1_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d32a6d39ea1e3004fbfc37de79f38b008f997583 Binary files /dev/null and b/Imgs/Whole/test1_0.jpg differ diff --git 
a/Imgs/Whole/test1_1.jpg b/Imgs/Whole/test1_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..72391507de1f2bd73f56928c9355e6d4b1b3b3f8 Binary files /dev/null and b/Imgs/Whole/test1_1.jpg differ diff --git a/Imgs/Whole/test1_2.jpg b/Imgs/Whole/test1_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..647256df8b458cbf98906f211632ab050dca0e06 Binary files /dev/null and b/Imgs/Whole/test1_2.jpg differ diff --git a/Imgs/Whole/test1_3.jpg b/Imgs/Whole/test1_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12126d4b26a52903e4a51eb8a3acaee51228c8a1 Binary files /dev/null and b/Imgs/Whole/test1_3.jpg differ diff --git a/Imgs/Whole/test2_0.jpg b/Imgs/Whole/test2_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ec5fac179590e9e83656ad3a4a5cbfc9addc711 Binary files /dev/null and b/Imgs/Whole/test2_0.jpg differ diff --git a/Imgs/Whole/test2_1.jpg b/Imgs/Whole/test2_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24faf15ee5a326e4a21ea5dd235768e8920aac64 Binary files /dev/null and b/Imgs/Whole/test2_1.jpg differ diff --git a/Imgs/Whole/test2_2.jpg b/Imgs/Whole/test2_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2f86fa9e4a2a0ee6b8cf748b3fd599e6fba82d80 Binary files /dev/null and b/Imgs/Whole/test2_2.jpg differ diff --git a/Imgs/Whole/test2_3.jpg b/Imgs/Whole/test2_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fd0acbc0e03a41a2ac33e11071a4cc19fcf142d3 Binary files /dev/null and b/Imgs/Whole/test2_3.jpg differ diff --git a/Imgs/Whole/test5_0.jpg b/Imgs/Whole/test5_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03cd688650bd4b337d15e15c3a6ad79c48f25eb4 Binary files /dev/null and b/Imgs/Whole/test5_0.jpg differ diff --git a/Imgs/Whole/test5_1.jpg b/Imgs/Whole/test5_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c1121f7b335e22a77d8e2473b00c8cf064ed6aaf Binary files /dev/null and b/Imgs/Whole/test5_1.jpg differ diff --git a/Imgs/Whole/test5_2.jpg b/Imgs/Whole/test5_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12f57ba8f70004ddd86eea8ed68815a3431ca23e Binary files /dev/null and b/Imgs/Whole/test5_2.jpg differ diff --git a/Imgs/Whole/test5_3.jpg b/Imgs/Whole/test5_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7a39a6d41c87501a7a4ad84b55e1fdb1a3241493 Binary files /dev/null and b/Imgs/Whole/test5_3.jpg differ diff --git a/Imgs/pipeline_a.png b/Imgs/pipeline_a.png new file mode 100644 index 0000000000000000000000000000000000000000..9a62bd821495874de3545d3c31748abfb9e6a2ae Binary files /dev/null and b/Imgs/pipeline_a.png differ diff --git a/Imgs/pipeline_b.png b/Imgs/pipeline_b.png new file mode 100644 index 0000000000000000000000000000000000000000..f3170eb89e00500343738a51f608510e5a435ad4 Binary files /dev/null and b/Imgs/pipeline_b.png differ diff --git a/README.md b/README.md index 0b04060ac2511ead86987e4fbcbbfe9246b01fe8..08297cd71ad65292609b42a2970a269e65974f3e 100644 --- a/README.md +++ b/README.md @@ -1,45 +1,171 @@ ---- -title: Deep Multi Scale -emoji: 🦀 -colorFrom: green -colorTo: green -sdk: gradio -app_file: app.py -pinned: false ---- +## [Blind Face Restoration via Deep Multi-scale Component Dictionaries](https://arxiv.org/pdf/2008.00418.pdf) -# Configuration +>##### __Note: This branch contains all the restoration results, including 512×512 face region and the final result by putting the enhanced face to the 
original input. The former version that can only generate the face result is put in [master branch](https://github.com/csxmli2016/DFDNet/tree/master)__ -`title`: _string_ -Display title for the Space -`emoji`: _string_ -Space emoji (emoji-only character allowed) +

+Overview of our proposed method. It consists of two parts: (a) the offline generation of multi-scale component dictionaries from a large number of high-quality images with diverse poses and expressions, where K-means produces K clusters for each component (i.e., left/right eyes, nose and mouth) at different feature scales; and (b) the restoration process, in which dictionary feature transfer (DFT) blocks provide the reference details in a progressive manner. Here, the DFT-i block takes the Scale-i component dictionaries as reference at the same feature level. A minimal code sketch of the offline dictionary step is given below the figures. +

+ -`colorFrom`: _string_ -Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +

(a) Offline generation of multi-scale component dictionaries.

+ +

(b) Architecture of our DFDNet for dictionary feature transfer.
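+To make figure (a) more concrete, below is a minimal, hypothetical sketch of the offline dictionary generation. It is not the training code of this repository: the component features (in practice, deep-feature activations of the cropped eyes/nose/mouth from many high-quality faces) are replaced by random vectors so the snippet runs standalone, and the atom count K, the sample count, and the key names (chosen to echo the `*_center.npy` files shipped with the pre-trained dictionaries) are assumptions.
+
+```python
+# Conceptual sketch only: K-means per component and per feature scale,
+# with the cluster centers kept as the dictionary atoms.
+import numpy as np
+from sklearn.cluster import KMeans
+
+components = ["left_eye", "right_eye", "nose", "mouth"]
+feature_scales = [256, 128, 64, 32]   # feature-map resolutions used for the dictionaries
+K = 512                               # assumed number of atoms per component and scale
+
+dictionaries = {}
+for scale in feature_scales:
+    for comp in components:
+        # Stand-in for component features gathered from many high-quality faces:
+        # one row per face crop, one column per feature channel.
+        feats = np.random.randn(4096, 128).astype("float32")
+        km = KMeans(n_clusters=K, n_init=4, random_state=0).fit(feats)
+        dictionaries[f"{comp}_{scale}_center"] = km.cluster_centers_
+
+# At restoration time, a degraded component feature is matched to its nearest atom,
+# whose high-quality details the DFT blocks then transfer back into the feature map.
+query = np.random.randn(128).astype("float32")
+centers = dictionaries["left_eye_64_center"]
+nearest = centers[np.argmin(np.linalg.norm(centers - query, axis=1))]
+```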

-`colorTo`: _string_ -Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) -`sdk`: _string_ -Can be either `gradio`, `streamlit`, or `static` +## Pre-train Models and dictionaries +Downloading from the following url and put them into ./. +- [BaiduNetDisk](https://pan.baidu.com/s/1K4fzjPiezVSMl5NjHoJCGQ) (s9ht) +- [GoogleDrive](https://drive.google.com/drive/folders/1bayYIUMCSGmoFPyd4Uu2Uwn347RW-vl5?usp=sharing) -`sdk_version` : _string_ -Only applicable for `streamlit` SDK. -See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. +The folder structure should be: + + . + ├── checkpoints + │ ├── facefh_dictionary + │ │ └── latest_net_G.pth + ├── weights + │ └── vgg19.pth + ├── DictionaryCenter512 + │ ├── right_eye_256_center.npy + │ ├── right_eye_128_center.npy + │ ├── right_eye_64_center.npy + │ ├── right_eye_32_center.npy + │ └── ... + └── ... -`app_file`: _string_ -Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code). -Path is relative to the root of the repository. +## Prerequisites +>([Video Installation Tutorial](https://www.youtube.com/watch?v=OTqGYMSKGF4). Thanks for [bycloudump](https://www.youtube.com/channel/UCfg9ux4m8P0YDITTPptrmLg)'s tremendous help.) +- Pytorch (≥1.1 is recommended) +- dlib +- dominate +- cv2 +- tqdm +- [face-alignment](https://github.com/1adrianb/face-alignment) + ```bash + cd ./FaceLandmarkDetection + python setup.py install + cd .. + ``` + -`models`: _List[string]_ -HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space. -Will be parsed automatically from your code if not specified here. +## Testing +```bash +python test_FaceDict.py --test_path ./TestData/TestWhole --results_dir ./Results/TestWholeResults --upscale_factor 4 --gpu_ids 0 +``` +#### __Four parameters can be changed for flexible usage:__ +``` +--test_path # test image path +--results_dir # save the results path +--upscale_factor # the upsample factor for the final result +--gpu_ids # gpu id. if use cpu, set gpu_ids=-1 +``` -`datasets`: _List[string]_ -HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space. -Will be parsed automatically from your code if not specified here. +>Note: our DFDNet can only generate 512×512 face result for any given face image. + +#### __Result path contains the following folder:__ +- Step0_Input: ```# Save the input image.``` +- Step1_AffineParam: ```# Save the crop and align parameters for copying the face result to the original input.``` +- Step1_CropImg: ```# Save the cropped face images and resize them to 512×512.``` +- Step2_Landmarks: ```# Save the facial landmarks for RoIAlign.``` +- Step3_RestoreCropFace: ```# Save the face restoration result (512×512).``` +- Step4_FinalResults: ```# Save the final restoration result by putting the enhanced face to the original input.``` + +## Some plausible restoration results on real low-quality images + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+Input | Crop and Align | Restore Face | Final Results (UpScaleWhole=4)
+
+(The original HTML table of result thumbnails could not be reproduced here; the corresponding example images are under `./Imgs/RealLR` and `./Imgs/ShowResults`.)
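+The intermediate landmark files (the Step2_Landmarks outputs above, and those written by `FaceLandmarkDetection/get_face_landmark.py`) are plain-text arrays with one `x y` row per landmark. Assuming that format, they can be inspected as follows; the path is only an example:
+
+```python
+import numpy as np
+
+# Hypothetical result path; substitute one of your own test images.
+pts = np.loadtxt("./Results/TestWholeResults/Step2_Landmarks/test1.jpg.txt")
+print(pts.shape)  # expected: (68, 2), coordinates in the 512x512 cropped face
+```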
+ +## TO DO LIST (if possible) +- [ ] Enhance all the faces in one image. +- [ ] Enhance the background. + + +## Citation + +``` +@InProceedings{Li_2020_ECCV, +author = {Li, Xiaoming and Chen, Chaofeng and Zhou, Shangchen and Lin, Xianhui and Zuo, Wangmeng and Zhang, Lei}, +title = {Blind Face Restoration via Deep Multi-scale Component Dictionaries}, +booktitle = {ECCV}, +year = {2020} +} +``` + +Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. -`pinned`: _boolean_ -Whether the Space stays on top of your list. diff --git a/TestData/TestWhole/test1.jpg b/TestData/TestWhole/test1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfc2edf0dd62baa5e6d254d15b20d60fe999f86a Binary files /dev/null and b/TestData/TestWhole/test1.jpg differ diff --git a/TestData/TestWhole/test2.jpg b/TestData/TestWhole/test2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6cb36b9ee462c1e86f90e88cfaaaa5321512a42d Binary files /dev/null and b/TestData/TestWhole/test2.jpg differ diff --git a/TestData/TestWhole/test3.jpg b/TestData/TestWhole/test3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3f829f02de8f8ec16fe4488f5c00fd64c43fbf07 Binary files /dev/null and b/TestData/TestWhole/test3.jpg differ diff --git a/TestData/TestWhole/test4.jpg b/TestData/TestWhole/test4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84d1182cd011482530fffe35e8b875f0c8608eb3 Binary files /dev/null and b/TestData/TestWhole/test4.jpg differ diff --git a/TestData/TestWhole/test5.jpg b/TestData/TestWhole/test5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1955cae069630b0b37124a5e8719faa2c4383063 Binary files /dev/null and b/TestData/TestWhole/test5.jpg differ diff --git a/TestData/TestWhole/test6.png b/TestData/TestWhole/test6.png new file mode 100644 index 0000000000000000000000000000000000000000..0e3c871d4599d8a06a12a60417da879c038af776 Binary files /dev/null and b/TestData/TestWhole/test6.png differ diff --git a/data/MotionBlurKernel/m_01.mat b/data/MotionBlurKernel/m_01.mat new file mode 100644 index 0000000000000000000000000000000000000000..e25bac95b9723201e4522bb5c4ae63bafc0a2cec Binary files /dev/null and b/data/MotionBlurKernel/m_01.mat differ diff --git a/data/MotionBlurKernel/m_02.mat b/data/MotionBlurKernel/m_02.mat new file mode 100644 index 0000000000000000000000000000000000000000..de05fde3f92121df1e1d8d7fc0ed3bd9575c2c90 Binary files /dev/null and b/data/MotionBlurKernel/m_02.mat differ diff --git a/data/MotionBlurKernel/m_03.mat b/data/MotionBlurKernel/m_03.mat new file mode 100644 index 0000000000000000000000000000000000000000..98a9aaf3701af01d2e16b3e6d9b0760a55eb5abf Binary files /dev/null and b/data/MotionBlurKernel/m_03.mat differ diff --git a/data/MotionBlurKernel/m_04.mat b/data/MotionBlurKernel/m_04.mat new file mode 100644 index 0000000000000000000000000000000000000000..9fc10561b7b3c6493a014c1e803798e2b395972f Binary files /dev/null and b/data/MotionBlurKernel/m_04.mat differ diff --git a/data/MotionBlurKernel/m_05.mat b/data/MotionBlurKernel/m_05.mat new file mode 100644 index 0000000000000000000000000000000000000000..3989860435391f9d0dafc94567a02d60bcfc957b Binary files /dev/null and b/data/MotionBlurKernel/m_05.mat differ diff --git a/data/MotionBlurKernel/m_06.mat b/data/MotionBlurKernel/m_06.mat new file mode 100644 index 0000000000000000000000000000000000000000..72d69530a288c3fd3779821c3162a1011d9dd795 Binary files /dev/null and b/data/MotionBlurKernel/m_06.mat differ diff --git a/data/MotionBlurKernel/m_07.mat b/data/MotionBlurKernel/m_07.mat new file mode 100644 index 0000000000000000000000000000000000000000..afcf374bdb69b8f29f0f63e23818031f7ba718e1 Binary files /dev/null and b/data/MotionBlurKernel/m_07.mat differ diff --git a/data/MotionBlurKernel/m_08.mat b/data/MotionBlurKernel/m_08.mat new file mode 100644 index 
0000000000000000000000000000000000000000..97fdc99e8695c9a24c3c38053fa812806465e68c Binary files /dev/null and b/data/MotionBlurKernel/m_08.mat differ diff --git a/data/MotionBlurKernel/m_09.mat b/data/MotionBlurKernel/m_09.mat new file mode 100644 index 0000000000000000000000000000000000000000..2b580721efe97c8c2048f6a763f0e2d9a73bd90c Binary files /dev/null and b/data/MotionBlurKernel/m_09.mat differ diff --git a/data/MotionBlurKernel/m_10.mat b/data/MotionBlurKernel/m_10.mat new file mode 100644 index 0000000000000000000000000000000000000000..fbf1e9b9e900e5575c80ebbf9a5e44c9866d7899 Binary files /dev/null and b/data/MotionBlurKernel/m_10.mat differ diff --git a/data/MotionBlurKernel/m_11.mat b/data/MotionBlurKernel/m_11.mat new file mode 100644 index 0000000000000000000000000000000000000000..13073988562afadec1b06d7e2aa5a149931e4270 Binary files /dev/null and b/data/MotionBlurKernel/m_11.mat differ diff --git a/data/MotionBlurKernel/m_12.mat b/data/MotionBlurKernel/m_12.mat new file mode 100644 index 0000000000000000000000000000000000000000..8238a8db58ccdcfc5a767ce864ae9fca2148d702 Binary files /dev/null and b/data/MotionBlurKernel/m_12.mat differ diff --git a/data/MotionBlurKernel/m_13.mat b/data/MotionBlurKernel/m_13.mat new file mode 100644 index 0000000000000000000000000000000000000000..bd9f6bdebf79d8a1c201833b938cc54f143f0111 Binary files /dev/null and b/data/MotionBlurKernel/m_13.mat differ diff --git a/data/MotionBlurKernel/m_14.mat b/data/MotionBlurKernel/m_14.mat new file mode 100644 index 0000000000000000000000000000000000000000..cde1d24dc67470cde9d79c2741f290395dfe27f3 Binary files /dev/null and b/data/MotionBlurKernel/m_14.mat differ diff --git a/data/MotionBlurKernel/m_15.mat b/data/MotionBlurKernel/m_15.mat new file mode 100644 index 0000000000000000000000000000000000000000..3538ce4a1d6c78344ddff7325d4e9ce878fa464e Binary files /dev/null and b/data/MotionBlurKernel/m_15.mat differ diff --git a/data/MotionBlurKernel/m_16.mat b/data/MotionBlurKernel/m_16.mat new file mode 100644 index 0000000000000000000000000000000000000000..3e3f9c6d5f6ec1d3a687a706779c7c42405b84be Binary files /dev/null and b/data/MotionBlurKernel/m_16.mat differ diff --git a/data/MotionBlurKernel/m_17.mat b/data/MotionBlurKernel/m_17.mat new file mode 100644 index 0000000000000000000000000000000000000000..eb9b331268f6671b0f9d5a3920ccc7eca89413b3 Binary files /dev/null and b/data/MotionBlurKernel/m_17.mat differ diff --git a/data/MotionBlurKernel/m_18.mat b/data/MotionBlurKernel/m_18.mat new file mode 100644 index 0000000000000000000000000000000000000000..7412fce0daf0ed74af4877c13e5960f5e3174468 Binary files /dev/null and b/data/MotionBlurKernel/m_18.mat differ diff --git a/data/MotionBlurKernel/m_19.mat b/data/MotionBlurKernel/m_19.mat new file mode 100644 index 0000000000000000000000000000000000000000..2c06d2aac4a33fcebaebc2d8b1c65a34e07ffb14 Binary files /dev/null and b/data/MotionBlurKernel/m_19.mat differ diff --git a/data/MotionBlurKernel/m_20.mat b/data/MotionBlurKernel/m_20.mat new file mode 100644 index 0000000000000000000000000000000000000000..70385f34b2b0605cca198ba14a3276adac2976af Binary files /dev/null and b/data/MotionBlurKernel/m_20.mat differ diff --git a/data/MotionBlurKernel/m_21.mat b/data/MotionBlurKernel/m_21.mat new file mode 100644 index 0000000000000000000000000000000000000000..773d50937455be6714f324ab7c7967bc02287aaf Binary files /dev/null and b/data/MotionBlurKernel/m_21.mat differ diff --git a/data/MotionBlurKernel/m_22.mat b/data/MotionBlurKernel/m_22.mat new file mode 100644 
index 0000000000000000000000000000000000000000..c051c507f9e46efe768e4308a9b555b2800c28f9 Binary files /dev/null and b/data/MotionBlurKernel/m_22.mat differ diff --git a/data/MotionBlurKernel/m_23.mat b/data/MotionBlurKernel/m_23.mat new file mode 100644 index 0000000000000000000000000000000000000000..48303ad2d9f4ef38ded3a52fd0ce1fd2f39651e3 Binary files /dev/null and b/data/MotionBlurKernel/m_23.mat differ diff --git a/data/MotionBlurKernel/m_24.mat b/data/MotionBlurKernel/m_24.mat new file mode 100644 index 0000000000000000000000000000000000000000..129f770db08b5fe1dea0c2335a7d466183935383 Binary files /dev/null and b/data/MotionBlurKernel/m_24.mat differ diff --git a/data/MotionBlurKernel/m_25.mat b/data/MotionBlurKernel/m_25.mat new file mode 100644 index 0000000000000000000000000000000000000000..42f2aa675a5e7660a5a3b48846a710839c03402b Binary files /dev/null and b/data/MotionBlurKernel/m_25.mat differ diff --git a/data/MotionBlurKernel/m_26.mat b/data/MotionBlurKernel/m_26.mat new file mode 100644 index 0000000000000000000000000000000000000000..480f0694f27dc46970022a289560fa36a421e73b Binary files /dev/null and b/data/MotionBlurKernel/m_26.mat differ diff --git a/data/MotionBlurKernel/m_27.mat b/data/MotionBlurKernel/m_27.mat new file mode 100644 index 0000000000000000000000000000000000000000..7e9cbb553b6775bdd8f66c7868ce1f6b345c7ec8 Binary files /dev/null and b/data/MotionBlurKernel/m_27.mat differ diff --git a/data/MotionBlurKernel/m_28.mat b/data/MotionBlurKernel/m_28.mat new file mode 100644 index 0000000000000000000000000000000000000000..03d3aa5eb2c02bf50435c036beb0cc0e95a7febe Binary files /dev/null and b/data/MotionBlurKernel/m_28.mat differ diff --git a/data/MotionBlurKernel/m_29.mat b/data/MotionBlurKernel/m_29.mat new file mode 100644 index 0000000000000000000000000000000000000000..e238ed66b930a38b0116ebba8ce913403d7ea086 Binary files /dev/null and b/data/MotionBlurKernel/m_29.mat differ diff --git a/data/MotionBlurKernel/m_30.mat b/data/MotionBlurKernel/m_30.mat new file mode 100644 index 0000000000000000000000000000000000000000..700e4c230c199f1c153e634fa6574dd158307124 Binary files /dev/null and b/data/MotionBlurKernel/m_30.mat differ diff --git a/data/MotionBlurKernel/m_31.mat b/data/MotionBlurKernel/m_31.mat new file mode 100644 index 0000000000000000000000000000000000000000..f335141069b0797699bfe758dddc96a1da29e295 Binary files /dev/null and b/data/MotionBlurKernel/m_31.mat differ diff --git a/data/MotionBlurKernel/m_32.mat b/data/MotionBlurKernel/m_32.mat new file mode 100644 index 0000000000000000000000000000000000000000..5acf9b436bcc1cf1f7d1980cff0893ee392e6548 Binary files /dev/null and b/data/MotionBlurKernel/m_32.mat differ diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2193cc788f4a263e13cf4890ff7c2231ce951776 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,80 @@ +# -- coding: utf-8 -- +import importlib +import torch.utils.data +from data.base_data_loader import BaseDataLoader +from data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + + # Given the option --dataset_mode [datasetname], + # the file "data/datasetname_dataset.py" + # will be imported. + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + # In the file, the class called DatasetNameDataset() will + # be instantiated. It has to be a subclass of BaseDataset, + # and it is case-insensitive. 
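+    # e.g. --dataset_mode aligned resolves to data/aligned_dataset.py and its AlignedDataset class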
+ dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + exit(0) + + return dataset + + +def get_option_setter(dataset_name): + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + dataset = find_dataset_using_name(opt.dataset_mode) + instance = dataset() + instance.initialize(opt) + print("dataset [%s] was created" % (instance.name())) + return instance + + +def CreateDataLoader(opt): + data_loader = CustomDatasetDataLoader() + data_loader.initialize(opt) + return data_loader + + +# Wrapper class of Dataset class that performs +# multi-threaded data loading +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = create_dataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + # DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False) + # 加载的数据集,DataSet对象,是否打乱,样本抽样,使用多线程加载的进程数,0表示不使用多线程,如何将多样本数据拼接成一个batch,是否将数据保存到pin memory,dataset种数据可数可能不是\ + # 一个batch_size的整数倍,drop_last 为True将多出来不足一个batch的数据丢弃 + + def load_data(self): + return self + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i * self.opt.batchSize >= self.opt.max_dataset_size: + break + yield data diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3eabcca633ec86e152f6fec4996aea9f541a6d7e --- /dev/null +++ b/data/aligned_dataset.py @@ -0,0 +1,137 @@ +# -- coding: utf-8 -- +import os.path +import random +import torchvision.transforms as transforms +import torch +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image, ImageFilter +import numpy as np +import cv2 +import math +from util import util +from scipy.io import loadmat +from PIL import Image +import PIL + + +class AlignedDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.partpath = opt.partroot + self.dir_AB = os.path.join(opt.dataroot, opt.phase) + self.AB_paths = sorted(make_dataset(self.dir_AB)) + self.is_real = opt.is_real + # assert(opt.resize_or_crop == 'resize_and_crop') + assert(opt.resize_or_crop == 'degradation') + + def AddNoise(self,img): # noise + if random.random() > 0.9: # + return img + self.sigma = np.random.randint(1, 11) + img_tensor = torch.from_numpy(np.array(img)).float() + noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0) + + noiseimg = torch.clamp(noise+img_tensor,0,255) + return Image.fromarray(np.uint8(noiseimg.numpy())) + + def AddBlur(self,img): # gaussian blur or motion blur + if random.random() > 0.9: # + return img + img = np.array(img) + if random.random() > 0.35: ##gaussian blur + blursize = 
random.randint(1,17) * 2 + 1 ##3,5,7,9,11,13,15 + blursigma = random.randint(3, 20) + img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10) + else: #motion blur + M = random.randint(1,32) + KName = './data/MotionBlurKernel/m_%02d.mat' % M + k = loadmat(KName)['kernel'] + k = k.astype(np.float32) + k /= np.sum(k) + img = cv2.filter2D(img,-1,k) + return Image.fromarray(img) + + def AddDownSample(self,img): # downsampling + if random.random() > 0.95: # + return img + sampler = random.randint(20, 100)*1.0 + img = img.resize((int(self.opt.fineSize/sampler*10.0), int(self.opt.fineSize/sampler*10.0)), Image.BICUBIC) + return img + + def AddJPEG(self,img): # JPEG compression + if random.random() > 0.6: # + return img + imQ = random.randint(40, 80) + img = np.array(img) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95 + _, encA = cv2.imencode('.jpg',img,encode_param) + img = cv2.imdecode(encA,1) + return Image.fromarray(img) + + def AddUpSample(self,img): + return img.resize((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC) + + def __getitem__(self, index): # + + AB_path = self.AB_paths[index] + Imgs = Image.open(AB_path).convert('RGB') + # # + A = Imgs.resize((self.opt.fineSize, self.opt.fineSize)) + A = transforms.ColorJitter(0.3, 0.3, 0.3, 0)(A) + C = A + A = self.AddUpSample(self.AddJPEG(self.AddNoise(self.AddDownSample(self.AddBlur(A))))) + + tmps = AB_path.split('/') + ImgName = tmps[-1] + Part_locations = self.get_part_location(self.partpath, ImgName, 2) + + A = transforms.ToTensor()(A) # + C = transforms.ToTensor()(C) + + ## + A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) # + C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C) # + return {'A':A, 'C':C, 'A_paths': AB_path,'Part_locations': Part_locations} + + def get_part_location(self, landmarkpath, imgname, downscale=1): + Landmarks = [] + with open(os.path.join(landmarkpath,imgname+'.txt'),'r') as f: + for line in f: + tmp = [np.float(i) for i in line.split(' ') if i != '\n'] + Landmarks.append(tmp) + Landmarks = np.array(Landmarks)/downscale # 512 * 512 + + Map_LE = list(np.hstack((range(17,22), range(36,42)))) + Map_RE = list(np.hstack((range(22,27), range(42,48)))) + Map_NO = list(range(29,36)) + Map_MO = list(range(48,68)) + #left eye + Mean_LE = np.mean(Landmarks[Map_LE],0) + L_LE = np.max((np.max(np.max(Landmarks[Map_LE],0) - np.min(Landmarks[Map_LE],0))/2,16)) + Location_LE = np.hstack((Mean_LE - L_LE + 1, Mean_LE + L_LE)).astype(int) + #right eye + Mean_RE = np.mean(Landmarks[Map_RE],0) + L_RE = np.max((np.max(np.max(Landmarks[Map_RE],0) - np.min(Landmarks[Map_RE],0))/2,16)) + Location_RE = np.hstack((Mean_RE - L_RE + 1, Mean_RE + L_RE)).astype(int) + #nose + Mean_NO = np.mean(Landmarks[Map_NO],0) + L_NO = np.max((np.max(np.max(Landmarks[Map_NO],0) - np.min(Landmarks[Map_NO],0))/2,16)) + Location_NO = np.hstack((Mean_NO - L_NO + 1, Mean_NO + L_NO)).astype(int) + #mouth + Mean_MO = np.mean(Landmarks[Map_MO],0) + L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) + + Location_MO = np.hstack((Mean_MO - L_MO + 1, Mean_MO + L_MO)).astype(int) + return Location_LE, Location_RE, Location_NO, Location_MO + + def __len__(self): # + return len(self.AB_paths) + + def name(self): + return 'AlignedDataset' diff --git a/data/base_data_loader.py b/data/base_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..5abf22edfd9baf7e25d281d1f7bfe714b6b9294e --- /dev/null +++ b/data/base_data_loader.py 
@@ -0,0 +1,10 @@ +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(self): + return None diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de1cee3e7296ac87a6470ab91af70da4ca6a2fe7 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,104 @@ +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms + + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + pass + + def __len__(self): + return 0 + + +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Resize(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'none': + transform_list.append(transforms.Lambda( + lambda img: __adjust(img))) + else: + raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +# just modify the width and height to be multiple of 4 +def __adjust(img): + ow, oh = img.size + + # the size needs to be a multiple of this number, + # because going through generator network may change img size + # and eventually cause size mismatch error + mult = 4 + if ow % mult == 0 and oh % mult == 0: + return img + w = (ow - 1) // mult + w = (w + 1) * mult + h = (oh - 1) // mult + h = (h + 1) * mult + + if ow != w or oh != h: + __print_size_warning(ow, oh, w, h) + + return img.resize((w, h), Image.BICUBIC) + + +def __scale_width(img, target_width): + ow, oh = img.size + + # the size needs to be a multiple of this number, + # because going through generator network may change img size + # and eventually cause size mismatch error + mult = 4 + assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult + if (ow == target_width and oh % mult == 0): + return img + w = target_width + target_height = int(target_width * oh / ow) + m = (target_height - 1) // mult + h = (m + 1) * mult + + if target_height != h: + __print_size_warning(target_width, target_height, w, h) + + return img.resize((w, h), Image.BICUBIC) + + +def __print_size_warning(ow, oh, w, h): + if not hasattr(__print_size_warning, 'has_printed'): + print("The image size needs to be a multiple of 4. " + "The loaded image size was (%d, %d), so it was adjusted to " + "(%d, %d). 
This adjustment will be done to all images " + "whose sizes are not multiples of 4" % (ow, oh, w, h)) + __print_size_warning.has_printed = True + + diff --git a/data/image_folder.py b/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..23df4c8bc6b692b84d26ae85d81e8a5ac3c5aad4 --- /dev/null +++ b/data/image_folder.py @@ -0,0 +1,69 @@ +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dirs): + images = [] + assert os.path.isdir(dirs), '%s is not a valid directory' % dirs + + for root, _, fnames in sorted(os.walk(dirs)): + fnames.sort() + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/data/single_dataset.py b/data/single_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f515b024b3c91aedfea824f075cf3a81c4376e --- /dev/null +++ b/data/single_dataset.py @@ -0,0 +1,42 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class SingleDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot) + + self.A_paths = make_dataset(self.dir_A) + + self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index] + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'SingleImageDataset' diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..06938b7f777230802305570e655af31be6121e03 --- /dev/null +++ b/data/unaligned_dataset.py @@ -0,0 +1,62 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import random + + +class UnalignedDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + if self.opt.serial_batches: + index_B = index % self.B_size + else: + index_B = random.randint(0, self.B_size - 1) + B_path = self.B_paths[index_B] + # print('(A, B) = (%d, %d)' % (index_A, index_B)) + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + + A = self.transform(A_img) + B = self.transform(B_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + output_nc = self.opt.input_nc + else: + input_nc = self.opt.input_nc + output_nc = self.opt.output_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + if output_nc == 1: # RGB to gray + tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 + B = tmp.unsqueeze(0) + return {'A': A, 'B': B, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return max(self.A_size, self.B_size) + + def name(self): + return 'UnalignedDataset' diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..259f43fb224809862d90ed0c8b9d234bfa809540 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,37 @@ +import importlib +from models.base_model import BaseModel + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of BaseModel, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace('_', '') + 'model' + + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) + + return model + + +def get_option_setter(model_name): + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + model = find_model_using_name(opt.model) + instance = model() + instance.initialize(opt) + return instance diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0838fd62bb126983c6b1f5648270af07733b8c Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/base_model.cpython-38.pyc b/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe02be4b617dc504e316e3aa82b615529c8ba03 Binary files /dev/null and b/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/models/__pycache__/networks.cpython-38.pyc b/models/__pycache__/networks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df2afa669b6ddfd3d9fa66569c3ad064e3c576d2 Binary files /dev/null and b/models/__pycache__/networks.cpython-38.pyc differ diff --git a/models/__pycache__/test_model.cpython-38.pyc b/models/__pycache__/test_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cb31788851358cd78aa016d2af8e4c7be893fd1 Binary files /dev/null and b/models/__pycache__/test_model.cpython-38.pyc differ diff --git a/models/base_model.py b/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1e2f7d81fd6d4f74547d54a0a63335ceab2cf6 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,173 @@ +import os +import torch +from collections import OrderedDict +from . 
import networks + + +class BaseModel(): + + # modify parser to add command line options, + # and also change the default values if needed + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + if opt.resize_or_crop != 'scale_width': + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.image_paths = [] + # self.optimizers = [] + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # load and print networks; create schedulers + def setup(self, opt, parser=None): + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + self.load_networks(opt.which_epoch) + self.print_networks(opt.verbose) + + # make models eval mode during test time + def eval(self): + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + # used in test time, wrapping `forward` in no_grad() so we don't save + # intermediate steps for backprop + def test(self): + with torch.no_grad(): + self.forward() + + # get image paths + def get_image_paths(self): + return self.image_paths + + def optimize_parameters(self): + pass + + # update learning rate (called once every epoch) + def update_learning_rate(self): + for scheduler in self.schedulers: + scheduler.step() + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + # return visualization images. train.py will display these images, and save the images to a html + def get_current_visuals(self): + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + # return traning losses/errors. train.py will print out these errors as debugging information + def get_current_losses(self): + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + # float(...) 
works for both scalar tensor and float number + errors_ret[name] = float(getattr(self, 'loss_' + name)) + return errors_ret + + # save models to the disk + def save_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (which_epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + # load models from the disk + def load_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (which_epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + # print('loading the model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + if not os.path.exists(load_path): + continue + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + model_dict = net.state_dict() + # new_dict = {k: v for k, v in state_dict.items() if k in model_dict.keys()} + new_dict = {} + for k, v in state_dict.items(): + if k in model_dict.keys(): + # print(k) + # if k == 'sff_branch.0.sff0.MaskModel.0.weight' or k =='sff_branch.0.sff1.MaskModel.0.weight' or k == 'sff_branch.1.sff0.MaskModel.0.weight' or k =='sff_branch.1.sff1.MaskModel.0.weight' or k == 'sff_branch.2.sff0.MaskModel.0.weight' or k =='sff_branch.2.sff1.MaskModel.0.weight' or k == 'sff_branch.3.sff0.MaskModel.0.weight' or k =='sff_branch.3.sff1.MaskModel.0.weight' or k == 'sff_branch.4.MaskModel.0.weight' : + # continue + # if 'Mask_CModel.model' in k: + # continue + new_dict[k] = v + model_dict.update(new_dict) + net.load_state_dict(model_dict) + + # print network information + def print_networks(self, verbose): + # print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + # if verbose: + # print(net) + # print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + # print('-----------------------------------------------') + + # set requies_grad=Fasle to avoid computation + def set_requires_grad(self, nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in 
net.parameters(): + param.requires_grad = requires_grad diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..761bfdd2f20b1ede959bad67783b3d1ca7415b9e --- /dev/null +++ b/models/networks.py @@ -0,0 +1,652 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler +import torch.nn.functional as F +from torch.nn import Parameter as P +from util import util +from torchvision import models +import scipy.io as sio +import numpy as np +import scipy.ndimage +import torch.nn.utils.spectral_norm as SpectralNorm + +from torch.autograd import Function +from math import sqrt +import random +import os +import math + +from sync_batchnorm import convert_model +#### + +############################################################################### +# Helper Functions +############################################################################### +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + + return norm_layer + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + + return scheduler + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net = convert_model(net) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) + + if init_flag: + + init_weights(net, init_type, gain=init_gain) + + return net + + +# compute adaptive instance norm +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. 
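+    # Note: this variant expects a single (C, H, W) feature map (no batch dimension);
+    # calc_mean_std_4D below handles batched (N, C, H, W) tensors. Both return the
+    # per-channel mean/std that the adaptive_instance_normalization helpers use to
+    # transfer style statistics onto the content feature.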
+ size = feat.size() + assert (len(size) == 3) + C, _ = size[:2] + feat_var = feat.contiguous().view(C, -1).var(dim=1) + eps + feat_std = feat_var.sqrt().view(C, 1, 1) + feat_mean = feat.contiguous().view(C, -1).mean(dim=1).view(C, 1, 1) + + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): # content_feat is degraded feature, style is ref feature + assert (content_feat.size()[:1] == style_feat.size()[:1]) + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + + normalized_feat = (content_feat - content_mean.expand( + size)) / content_std.expand(size) + + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def calc_mean_std_4D(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature + # assert (content_feat.size()[:2] == style_feat.size()[:2]) + size = content_feat.size() + style_mean, style_std = calc_mean_std_4D(style_feat) + + content_mean, content_std = calc_mean_std_4D(content_feat) + normalized_feat = (content_feat - content_mean.expand( + size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def define_G(which_model_netG, gpu_ids=[]): + if which_model_netG == 'UNetDictFace': + netG = UNetDictFace(64) + init_flag = False + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + return init_net(netG, 'normal', 0.02, gpu_ids, init_flag) + + +############################################################################## +# Classes +############################################################################################################################################ + + +def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True): + return nn.Sequential( + SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), +# conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias), +# nn.BatchNorm2d(out_channels), + nn.LeakyReLU(0.2), + SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), + ) +class MSDilateBlock(nn.Module): + def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True): + super(MSDilateBlock, self).__init__() + self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias) + self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias) + self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias) + self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias) + self.convi = 
SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)) + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(x) + conv3 = self.conv3(x) + conv4 = self.conv4(x) + cat = torch.cat([conv1, conv2, conv3, conv4], 1) + out = self.convi(cat) + x + return out + +##############################UNetFace######################### +class AdaptiveInstanceNorm(nn.Module): + def __init__(self, in_channel): + super().__init__() + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + style_mean, style_std = calc_mean_std_4D(style) + out = self.norm(input) + size = input.size() + out = style_std.expand(size) * out + style_mean.expand(size) + return out + +class BlurFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + + grad_input = F.conv2d( + grad_output, kernel_flip, padding=1, groups=grad_output.shape[1] + ) + return grad_input + + @staticmethod + def backward(ctx, gradgrad_output): + kernel, kernel_flip = ctx.saved_tensors + + grad_input = F.conv2d( + gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1] + ) + return grad_input, None, None + + +class BlurFunction(Function): + @staticmethod + def forward(ctx, input, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + + output = F.conv2d(input, kernel, padding=1, groups=input.shape[1]) + + return output + + @staticmethod + def backward(ctx, grad_output): + kernel, kernel_flip = ctx.saved_tensors + + grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) + + return grad_input, None, None + +blur = BlurFunction.apply + + +class Blur(nn.Module): + def __init__(self, channel): + super().__init__() + + weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) + weight = weight.view(1, 1, 3, 3) + weight = weight / weight.sum() + weight_flip = torch.flip(weight, [2, 3]) + + self.register_buffer('weight', weight.repeat(channel, 1, 1, 1)) + self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1)) + + def forward(self, input): + return blur(input, self.weight, self.weight_flip) + +class EqualLR: + def __init__(self, name): + self.name = name + + def compute_weight(self, module): + weight = getattr(module, self.name + '_orig') + fan_in = weight.data.size(1) * weight.data[0][0].numel() + return weight * sqrt(2 / fan_in) + @staticmethod + def apply(module, name): + fn = EqualLR(name) + + weight = getattr(module, name) + del module._parameters[name] + module.register_parameter(name + '_orig', nn.Parameter(weight.data)) + module.register_forward_pre_hook(fn) + + return fn + + def __call__(self, module, input): + weight = self.compute_weight(module) + setattr(module, self.name, weight) + +def equal_lr(module, name='weight'): + EqualLR.apply(module, name) + return module + +class EqualConv2d(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + conv = nn.Conv2d(*args, **kwargs) + conv.weight.data.normal_() + conv.bias.data.zero_() + self.conv = equal_lr(conv) + def forward(self, input): + return self.conv(input) + +class NoiseInjection(nn.Module): + def __init__(self, channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + def forward(self, image, noise): + return image + self.weight * noise + +class StyledUpBlock(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False): + 
super().__init__() + if upsample: + self.conv1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + Blur(out_channel), + # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding), + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + ) + else: + self.conv1 = nn.Sequential( + Blur(in_channel), + # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding) + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + # EqualConv2d(out_channel, out_channel, kernel_size, padding=padding), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + # Blur(out_channel), + ) + # self.noise1 = equal_lr(NoiseInjection(out_channel)) + # self.adain1 = AdaptiveInstanceNorm(out_channel) + self.lrelu1 = nn.LeakyReLU(0.2) + + # self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding) + # self.noise2 = equal_lr(NoiseInjection(out_channel)) + # self.adain2 = AdaptiveInstanceNorm(out_channel) + # self.lrelu2 = nn.LeakyReLU(0.2) + + self.ScaleModel1 = nn.Sequential( + # Blur(in_channel), + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + # nn.Conv2d(in_channel,out_channel,3, 1, 1), + nn.LeakyReLU(0.2, True), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)) + # nn.Conv2d(out_channel, out_channel, 3, 1, 1) + ) + self.ShiftModel1 = nn.Sequential( + # Blur(in_channel), + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + # nn.Conv2d(in_channel,out_channel,3, 1, 1), + nn.LeakyReLU(0.2, True), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), + nn.Sigmoid(), + # nn.Conv2d(out_channel, out_channel, 3, 1, 1) + ) + + def forward(self, input, style): + out = self.conv1(input) +# out = self.noise1(out, noise) + out = self.lrelu1(out) + + Shift1 = self.ShiftModel1(style) + Scale1 = self.ScaleModel1(style) + out = out * Scale1 + Shift1 + # out = self.adain1(out, style) + outup = self.convup(out) + + return outup + +############################################################################## +##Face Dictionary +############################################################################## +class VGGFeat(torch.nn.Module): + """ + Input: (B, C, H, W), RGB, [-1, 1] + """ + def __init__(self, weight_path='./weights/vgg19.pth'): + super().__init__() + self.model = models.vgg19(pretrained=False) + self.build_vgg_layers() + + self.model.load_state_dict(torch.load(weight_path)) + + self.register_parameter("RGB_mean", nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))) + self.register_parameter("RGB_std", nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))) + + # self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def build_vgg_layers(self): + vgg_pretrained_features = self.model.features + self.features = [] + # feature_layers = [0, 3, 8, 17, 26, 35] + feature_layers = [0, 8, 17, 26, 35] + for i in range(len(feature_layers)-1): + module_layers = torch.nn.Sequential() + for j in range(feature_layers[i], feature_layers[i+1]): + module_layers.add_module(str(j), vgg_pretrained_features[j]) + self.features.append(module_layers) + self.features = torch.nn.ModuleList(self.features) + + def preprocess(self, x): + x = (x + 1) / 2 + x = (x - self.RGB_mean) / self.RGB_std + if x.shape[3] < 224: + x = 
torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False) + return x + + def forward(self, x): + x = self.preprocess(x) + features = [] + for m in self.features: + # print(m) + x = m(x) + features.append(x) + return features + +def compute_sum(x, axis=None, keepdim=False): + if not axis: + axis = range(len(x.shape)) + for i in sorted(axis, reverse=True): + x = torch.sum(x, dim=i, keepdim=keepdim) + return x +def ToRGB(in_channel): + return nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel,in_channel,3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(in_channel,3,3, 1, 1)) + ) + +def AttentionBlock(in_channel): + return nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)) + ) + +class UNetDictFace(nn.Module): + def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'): + super().__init__() + + self.part_sizes = np.array([80,80,50,110]) # size for 512 + self.feature_sizes = np.array([256,128,64,32]) + self.channel_sizes = np.array([128,256,512,512]) + Parts = ['left_eye','right_eye','nose','mouth'] + self.Dict_256 = {} + self.Dict_128 = {} + self.Dict_64 = {} + self.Dict_32 = {} + for j,i in enumerate(Parts): + f_256 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_256_center.npy'.format(i)), allow_pickle=True)) + + f_256_reshape = f_256.reshape(f_256.size(0),self.channel_sizes[0],self.part_sizes[j]//2,self.part_sizes[j]//2) + max_256 = torch.max(torch.sqrt(compute_sum(torch.pow(f_256_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4])) + self.Dict_256[i] = f_256_reshape #/ max_256 + + f_128 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_128_center.npy'.format(i)), allow_pickle=True)) + + f_128_reshape = f_128.reshape(f_128.size(0),self.channel_sizes[1],self.part_sizes[j]//4,self.part_sizes[j]//4) + max_128 = torch.max(torch.sqrt(compute_sum(torch.pow(f_128_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4])) + self.Dict_128[i] = f_128_reshape #/ max_128 + + f_64 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_64_center.npy'.format(i)), allow_pickle=True)) + + f_64_reshape = f_64.reshape(f_64.size(0),self.channel_sizes[2],self.part_sizes[j]//8,self.part_sizes[j]//8) + max_64 = torch.max(torch.sqrt(compute_sum(torch.pow(f_64_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4])) + self.Dict_64[i] = f_64_reshape #/ max_64 + + f_32 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_32_center.npy'.format(i)), allow_pickle=True)) + + f_32_reshape = f_32.reshape(f_32.size(0),self.channel_sizes[3],self.part_sizes[j]//16,self.part_sizes[j]//16) + max_32 = torch.max(torch.sqrt(compute_sum(torch.pow(f_32_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4])) + self.Dict_32[i] = f_32_reshape #/ max_32 + + self.le_256 = AttentionBlock(128) + self.le_128 = AttentionBlock(256) + self.le_64 = AttentionBlock(512) + self.le_32 = AttentionBlock(512) + + self.re_256 = AttentionBlock(128) + self.re_128 = AttentionBlock(256) + self.re_64 = AttentionBlock(512) + self.re_32 = AttentionBlock(512) + + self.no_256 = AttentionBlock(128) + self.no_128 = AttentionBlock(256) + self.no_64 = AttentionBlock(512) + self.no_32 = AttentionBlock(512) + + self.mo_256 = AttentionBlock(128) + self.mo_128 = AttentionBlock(256) + self.mo_64 = AttentionBlock(512) + self.mo_32 = AttentionBlock(512) + + #norm + self.VggExtract = VGGFeat() + + ###################### + 
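+        # Decoder: fuse the deepest VGG feature with multi-scale dilated convolutions
+        # (MSDilateBlock), then upsample through StyledUpBlocks whose scale/shift branches
+        # are driven by the dictionary-updated features at each resolution, finishing
+        # with two UpResBlocks and a Tanh output layer.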
self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1]) # + + self.up0 = StyledUpBlock(ngf*8,ngf*8) + self.up1 = StyledUpBlock(ngf*8, ngf*4) # + self.up2 = StyledUpBlock(ngf*4, ngf*2) # + self.up3 = StyledUpBlock(ngf*2, ngf) # + self.up4 = nn.Sequential( # 128 + # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + # nn.BatchNorm2d(32), + nn.LeakyReLU(0.2), + UpResBlock(ngf), + UpResBlock(ngf), + # SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)), + nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1), + nn.Tanh() + ) + self.to_rgb0 = ToRGB(ngf*8) + self.to_rgb1 = ToRGB(ngf*4) + self.to_rgb2 = ToRGB(ngf*2) + self.to_rgb3 = ToRGB(ngf*1) + + # for param in self.BlurInputConv.parameters(): + # param.requires_grad = False + + def forward(self,input, part_locations): + + VggFeatures = self.VggExtract(input) + # for b in range(input.size(0)): + b = 0 + UpdateVggFeatures = [] + for i, f_size in enumerate(self.feature_sizes): + cur_feature = VggFeatures[i] + update_feature = cur_feature.clone() #* 0 + cur_part_sizes = self.part_sizes // (512/f_size) + + dicts_feature = getattr(self, 'Dict_'+str(f_size)) + LE_Dict_feature = dicts_feature['left_eye'].to(input) + RE_Dict_feature = dicts_feature['right_eye'].to(input) + NO_Dict_feature = dicts_feature['nose'].to(input) + MO_Dict_feature = dicts_feature['mouth'].to(input) + + le_location = (part_locations[0][b] // (512/f_size)).int() + re_location = (part_locations[1][b] // (512/f_size)).int() + no_location = (part_locations[2][b] // (512/f_size)).int() + mo_location = (part_locations[3][b] // (512/f_size)).int() + + LE_feature = cur_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]].clone() + RE_feature = cur_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]].clone() + NO_feature = cur_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]].clone() + MO_feature = cur_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]].clone() + + #resize + LE_feature_resize = F.interpolate(LE_feature,(LE_Dict_feature.size(2),LE_Dict_feature.size(3)),mode='bilinear',align_corners=False) + RE_feature_resize = F.interpolate(RE_feature,(RE_Dict_feature.size(2),RE_Dict_feature.size(3)),mode='bilinear',align_corners=False) + NO_feature_resize = F.interpolate(NO_feature,(NO_Dict_feature.size(2),NO_Dict_feature.size(3)),mode='bilinear',align_corners=False) + MO_feature_resize = F.interpolate(MO_feature,(MO_Dict_feature.size(2),MO_Dict_feature.size(3)),mode='bilinear',align_corners=False) + + LE_Dict_feature_norm = adaptive_instance_normalization_4D(LE_Dict_feature, LE_feature_resize) + RE_Dict_feature_norm = adaptive_instance_normalization_4D(RE_Dict_feature, RE_feature_resize) + NO_Dict_feature_norm = adaptive_instance_normalization_4D(NO_Dict_feature, NO_feature_resize) + MO_Dict_feature_norm = adaptive_instance_normalization_4D(MO_Dict_feature, MO_feature_resize) + + LE_score = F.conv2d(LE_feature_resize, LE_Dict_feature_norm) + + LE_score = F.softmax(LE_score.view(-1),dim=0) + LE_index = torch.argmax(LE_score) + LE_Swap_feature = F.interpolate(LE_Dict_feature_norm[LE_index:LE_index+1], (LE_feature.size(2), LE_feature.size(3))) + + LE_Attention = getattr(self, 'le_'+str(f_size))(LE_Swap_feature-LE_feature) + LE_Att_feature = LE_Attention * LE_Swap_feature + + + RE_score = F.conv2d(RE_feature_resize, RE_Dict_feature_norm) + RE_score = F.softmax(RE_score.view(-1),dim=0) + RE_index = torch.argmax(RE_score) + 
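+            # Same procedure as the left eye above: pick the best-matching dictionary
+            # atom, resize it back to the cropped part region, and blend it in through
+            # the attention branch.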
RE_Swap_feature = F.interpolate(RE_Dict_feature_norm[RE_index:RE_index+1], (RE_feature.size(2), RE_feature.size(3))) + + RE_Attention = getattr(self, 're_'+str(f_size))(RE_Swap_feature-RE_feature) + RE_Att_feature = RE_Attention * RE_Swap_feature + + NO_score = F.conv2d(NO_feature_resize, NO_Dict_feature_norm) + NO_score = F.softmax(NO_score.view(-1),dim=0) + NO_index = torch.argmax(NO_score) + NO_Swap_feature = F.interpolate(NO_Dict_feature_norm[NO_index:NO_index+1], (NO_feature.size(2), NO_feature.size(3))) + + NO_Attention = getattr(self, 'no_'+str(f_size))(NO_Swap_feature-NO_feature) + NO_Att_feature = NO_Attention * NO_Swap_feature + + + MO_score = F.conv2d(MO_feature_resize, MO_Dict_feature_norm) + MO_score = F.softmax(MO_score.view(-1),dim=0) + MO_index = torch.argmax(MO_score) + MO_Swap_feature = F.interpolate(MO_Dict_feature_norm[MO_index:MO_index+1], (MO_feature.size(2), MO_feature.size(3))) + + MO_Attention = getattr(self, 'mo_'+str(f_size))(MO_Swap_feature-MO_feature) + MO_Att_feature = MO_Attention * MO_Swap_feature + + update_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]] = LE_Att_feature + LE_feature + update_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]] = RE_Att_feature + RE_feature + update_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]] = NO_Att_feature + NO_feature + update_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]] = MO_Att_feature + MO_feature + + UpdateVggFeatures.append(update_feature) + + fea_vgg = self.MSDilate(VggFeatures[3]) + #new version + fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3]) + # out1 = F.interpolate(fea_up0,(512,512)) + # out1 = self.to_rgb0(out1) + + fea_up1 = self.up1( fea_up0, UpdateVggFeatures[2]) # + # out2 = F.interpolate(fea_up1,(512,512)) + # out2 = self.to_rgb1(out2) + + fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1]) # + # out3 = F.interpolate(fea_up2,(512,512)) + # out3 = self.to_rgb2(out3) + + fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0]) # + # out4 = F.interpolate(fea_up3,(512,512)) + # out4 = self.to_rgb3(out4) + + output = self.up4(fea_up3) # + + + return output #+ out4 + out3 + out2 + out1 + #0 128 * 256 * 256 + #1 256 * 128 * 128 + #2 512 * 64 * 64 + #3 512 * 32 * 32 + + +class UpResBlock(nn.Module): + def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d): + super(UpResBlock, self).__init__() + self.Model = nn.Sequential( + # SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + conv_layer(dim, dim, 3, 1, 1), + # norm_layer(dim), + nn.LeakyReLU(0.2,True), + # SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + conv_layer(dim, dim, 3, 1, 1), + ) + def forward(self, x): + out = x + self.Model(x) + return out + +class VggClassNet(nn.Module): + def __init__(self, select_layer = ['0','5','10','19']): + super(VggClassNet, self).__init__() + self.select = select_layer + self.vgg = models.vgg19(pretrained=True).features + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + features = [] + for name, layer in self.vgg._modules.items(): + x = layer(x) + if name in self.select: + features.append(x) + return features + + +if __name__ == '__main__': + print('this is network') + + diff --git a/models/test_model.py b/models/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae702e5b9cc19eb36c208d72ab5e07c6c38f7eca --- /dev/null +++ b/models/test_model.py @@ -0,0 +1,48 @@ +from .base_model import BaseModel +from . 
import networks +import torch +import numpy as np +import torchvision.transforms as transforms +import PIL + +import torch.nn.functional as F + +class TestModel(BaseModel): + def name(self): + return 'TestModel' + + @staticmethod + def modify_commandline_options(parser, is_train=True): + assert not is_train, 'TestModel cannot be used in train mode' + parser.set_defaults(dataset_mode='aligned') + + parser.add_argument('--model_suffix', type=str, default='', + help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will' + ' be loaded as the generator of TestModel') + return parser + + def initialize(self, opt): + assert(not opt.isTrain) + BaseModel.initialize(self, opt) + + # specify the training losses you want to print out. The program will call base_model.get_current_losses + self.loss_names = [] + # specify the images you want to save/display. The program will call base_model.get_current_visuals + self.visual_names = ['fake_A','real_A'] + self.model_names = ['G'] + + self.netG = networks.define_G('UNetDictFace',self.gpu_ids) + + def set_input(self, input): + self.real_A = input['A'].to(self.device) #degraded img + self.real_C = input['C'].to(self.device) #groundtruth + self.image_paths = input['A_paths'] + self.Part_locations = input['Part_locations'] + + def forward(self): + + self.fake_A = self.netG(self.real_A, self.Part_locations) # + # try: + # self.fake_A = self.netG(self.real_A, self.Part_locations) #生成图 + # except: + # self.fake_A = self.real_A diff --git a/options/__pycache__/base_options.cpython-38.pyc b/options/__pycache__/base_options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f35f849babac414bc3258bb84209d8febe391b Binary files /dev/null and b/options/__pycache__/base_options.cpython-38.pyc differ diff --git a/options/__pycache__/test_options.cpython-38.pyc b/options/__pycache__/test_options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4164c89c22506a57a45f76d195bdf290850137a Binary files /dev/null and b/options/__pycache__/test_options.cpython-38.pyc differ diff --git a/options/base_options.py b/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..ad99c47f7a17928280899abc138b1c9ff436e693 --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,103 @@ +import argparse +import os +from util import util +import torch +import models +import data + + +class BaseOptions(): + def __init__(self): + self.initialized = False + + def initialize(self, parser): + + parser.add_argument('--batchSize', type=int, default=2, help='input batch size') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--name', type=str, default='facefh_dictionary', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--model', type=str, default='faceDict', help='chooses which model to use. 
cycle_gan, pix2pix, test') + parser.add_argument('--which_direction', type=str, default='BtoA', help='AtoB or BtoA') + parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--resize_or_crop', type=str, default='degradation', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') + parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal|xavier|kaiming|orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}') + self.initialized = True + return parser + + def gather_options(self): + # initialize parser with basic options + if not self.initialized: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + + opt, _ = parser.parse_known_args() + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + + opt, _ = parser.parse_known_args() # parse again with the new defaults + + # modify dataset-related parser options + dataset_name = opt.dataset_mode + + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + self.parser = parser + + return parser.parse_args() + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + # self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/options/test_options.py b/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0eeb902981c81da164b7c1a82e6f57ef75722f --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,20 
@@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # parser.add_argument('--dataroot', type=str, default='/home/Data/AllDataImages/2018_FaceFH', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--dataroot', type=str, default='', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--phase', type=str, default='', help='train, val, test, etc') + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.set_defaults(model='test') + + parser.add_argument('--test_path', type=str, default='./TestData/TestWhole', help='test images path') + parser.add_argument('--results_dir', type=str, default='./Results/TestWholeResults', help='saves results here.') + parser.add_argument('--upscale_factor', type=int, default=4, help='upscale factor for the whole input image (not for face)') + + + self.isTrain = False + return parser diff --git a/packages/FFHQ_template.npy b/packages/FFHQ_template.npy new file mode 100644 index 0000000000000000000000000000000000000000..a3cf332405b244e0b55c93c599e7d7753cf52d83 Binary files /dev/null and b/packages/FFHQ_template.npy differ diff --git a/packages/mmod_human_face_detector.dat b/packages/mmod_human_face_detector.dat new file mode 100644 index 0000000000000000000000000000000000000000..f112a0a45dbda96080352c45f615d00ddd4f130c Binary files /dev/null and b/packages/mmod_human_face_detector.dat differ diff --git a/packages/shape_predictor_5_face_landmarks.dat b/packages/shape_predictor_5_face_landmarks.dat new file mode 100644 index 0000000000000000000000000000000000000000..67878ed3894d929c5e03bd1ad2d931cc6d745ee7 Binary files /dev/null and b/packages/shape_predictor_5_face_landmarks.dat differ diff --git a/sync_batchnorm/__init__.py b/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a10989fbef38485841fdbd1af72f53062a7b556d --- /dev/null +++ b/sync_batchnorm/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/sync_batchnorm/batchnorm.py b/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..24c0950707d07102d06b5fdf64c74a347ec3186c --- /dev/null +++ b/sync_batchnorm/batchnorm.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
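+# Usage sketch (illustrative): convert an existing model and wrap it so the replication
+# callback is invoked on every device, e.g.
+#
+#     from sync_batchnorm import convert_model, DataParallelWithCallback
+#     model = convert_model(model)          # nn.BatchNorm*d -> SynchronizedBatchNorm*d
+#     model = DataParallelWithCallback(model, device_ids=[0, 1])
+#
+# In this repository, networks.init_net() calls convert_model() on the generator before
+# wrapping it in torch.nn.DataParallel.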
+ +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. 
+ # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, 
module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/sync_batchnorm/batchnorm_reimpl.py b/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/sync_batchnorm/comm.py b/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. 
Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
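+        # Hand each slave its share of the results computed by the master callback, then
+        # drain the True tokens that run_slave() pushes back once each copy has received
+        # its result.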
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/sync_batchnorm/replicate.py b/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/sync_batchnorm/unittest.py b/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..bed56f1caa929ac3e9a57c583f8d3e42624f58be --- /dev/null +++ b/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/util/Loss.py b/util/Loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3d9bc582b79e34ab30e8f3ed03f4cf2bb8ccad --- /dev/null +++ b/util/Loss.py @@ -0,0 +1,41 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class TVLoss(nn.Module): + def __init__(self,TVLoss_weight=1): + super(TVLoss,self).__init__() + self.TVLoss_weight = TVLoss_weight + + def forward(self,x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self._tensor_size(x[:,:,1:,:]) + count_w = self._tensor_size(x[:,:,:,1:]) + h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() + w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() + return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size + + def _tensor_size(self,t): + return t.size()[1]*t.size()[2]*t.size()[3] + + +class hinge_loss(nn.Module): + def __init__(self): + super(hinge_loss, self).__init__() + + def forward(self, dis_fake, dis_real): + loss_real = torch.mean(F.relu(1. - dis_real)) + loss_fake = torch.mean(F.relu(1. 
+ dis_fake)) + return loss_real + loss_fake + + +class hinge_loss_G(nn.Module): + def __init__(self): + super(hinge_loss_G, self).__init__() + + def forward(self, dis_fake): + loss_fake = -torch.mean(dis_fake) + return loss_fake \ No newline at end of file diff --git a/util/ROI_MLS.py b/util/ROI_MLS.py new file mode 100644 index 0000000000000000000000000000000000000000..0987909e797831c3b350150a4706016b16159bc3 --- /dev/null +++ b/util/ROI_MLS.py @@ -0,0 +1,542 @@ +import numpy as np +import torch +import torch.nn.functional as F + +#mls affine inv + +def mls_affine_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0): + ''' Affine inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + # height = image.shape[0] + # width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + + reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0) # [2, 2, grow, gcol] + try: + inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol] + adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol] + adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol] + inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + mul_left = reshaped_v - qstar # [2, grow, gcol] + reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2] + mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol] + reshaped_mul_right =mul_right.transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right) # [grow, gcol, 1, 2] + reshaped_temp = temp.reshape(grow, gcol, 
2).transpose(2, 0, 1) # [2, grow, gcol] + + # Get final image transfomer -- 3-D array + transformers = reshaped_temp + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformers.astype(np.float), transformed_image + + +def mls_affine_deformation_inv_final(height, width, channel, p, q, alpha=1.0, density=1.0): + ''' Affine inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + # height = image.shape[0] + # width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + + reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0) # [2, 2, grow, gcol] + try: + inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol] + adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol] + adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol] + inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + mul_left = reshaped_v - qstar # [2, grow, gcol] + reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2] + mul_right = 
np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol] + reshaped_mul_right =mul_right.transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right) # [grow, gcol, 1, 2] + reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol] + + # Get final image transfomer -- 3-D array + transformers = reshaped_temp + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + return transformers + +def mls_similarity_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0): + ''' Similarity inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + + mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) * + reshaped_phat1.transpose(0, 3, 4, 1, 2), + reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1] + reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] 
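+ # neg_phat_verti is -perp(phat): the two components of phat swapped, with the second negated.
+ # Stacked with phat below, it forms the [phat; -perp(phat)] block used by the MLS similarity solution.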
+ reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol] + Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2), + mul_right.transpose(0, 3, 4, 1, 2)), + axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1] + Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1] + Delta_verti[...,0,:] = -Delta_verti[...,0,:] + B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2] + try: + inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(B) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1] + adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2] + adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2] + inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol] + + v_minus_qstar_mul_mu = (reshaped_v - qstar) * reshaped_mu # [2, grow, gcol] + + # Get final image transfomer -- 3-D array + reshaped_v_minus_qstar_mul_mu = v_minus_qstar_mul_mu.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + transformers = np.matmul(reshaped_v_minus_qstar_mul_mu.transpose(2, 3, 0, 1), + inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformers, transformed_image + + +def mls_rigid_deformation_inv(image, height, width, channel, p, q, alpha=1.0, density=1.0): + ''' Rigid inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. 
+ ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + + mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) * + reshaped_phat1.transpose(0, 3, 4, 1, 2), + reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1] + reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] 
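+ # same -perp(phat) construction as in the similarity case; the rigid variant additionally
+ # rescales the transformed offset to length |v - qstar| further below (norm_vqstar / norm_temp).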
+ reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol] + Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2), + mul_right.transpose(0, 3, 4, 1, 2)), + axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1] + Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1] + Delta_verti[...,0,:] = -Delta_verti[...,0,:] + B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2] + try: + inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(B) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1] + adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2] + adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2] + inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol] + + vqstar = reshaped_v - qstar # [2, grow, gcol] + reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + + # Get final image transfomer -- 3-D array + temp = np.matmul(reshaped_vqstar.transpose(2, 3, 0, 1), + inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol] + norm_temp = np.linalg.norm(temp, axis=0, keepdims=True) # [1, grow, gcol] + norm_vqstar = np.linalg.norm(vqstar, axis=0, keepdims=True) # [1, grow, gcol] + transformers = temp / norm_temp * norm_vqstar + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + # transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformers, transformed_image + +# mls rigid algorithm +def mls_rigid_deformation_inv_wy(image, height, width, channel, p, q, alpha=1.0, density=1.0): + ''' + Rigid inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. 
+ ''' + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = torch.linspace(0, width, steps=int(width * density)) + gridY = torch.linspace(0, height, steps=int(height * density)) + vx, vy = torch.meshgrid(gridY, gridX) + + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = torch.stack((vx, vy), dim=0) # [2, grow, gcol] + + w = 1.0 / torch.sum((reshaped_p - reshaped_v) ** 2, dim=1) ** alpha # [ctrls, grow, gcol] + w[w == torch.tensor(float("inf"))] = 2 ** 31 - 1 + pstar = torch.sum(w * reshaped_p.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = torch.sum(w * reshaped_q.permute(1, 0, 2, 3), dim=1) / torch.sum(w, dim=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + + neg_phat_verti = phat[:, [1, 0], ...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1, ...] = -neg_phat_verti[:, 1, ...] + reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_right = torch.cat((reshaped_phat1, reshaped_neg_phat_verti), dim=1) # [ctrls, 2, 2, grow, gcol] + mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol] + Delta = torch.sum(torch.matmul(mul_left.permute(0, 3, 4, 1, 2), + mul_right.permute(0, 3, 4, 1, 2)), + dim=0).permute(0, 1, 3, 2) # [grow, gcol, 2, 1] + Delta_verti = Delta[..., [1, 0], :] # [grow, gcol, 2, 1] + Delta_verti[..., 0, :] = -Delta_verti[..., 0, :] + B = torch.cat((Delta, Delta_verti), dim=3) # [grow, gcol, 2, 2] + try: + inv_B = torch.inverse(B) # [grow, gcol, 2, 2] + flag = False + except: + flag = True + det = np.linalg.det(B.numpy()) + det = torch.from_numpy(det) # [grow, gcol] + det[det < 1e-8] = torch.tensor(float("inf")) + reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1] + adjoint = B[:, :, [[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2] + adjoint[:, :, [0, 1], [1, 0]] = -adjoint[:, :, [0, 1], [1, 0]] # [grow, gcol, 2, 2] + inv_B = (adjoint / reshaped_det).permute(2, 3, 0, 1) # [2, 2, grow, gcol] + + vqstar = reshaped_v - qstar # [2, grow, gcol] + reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + + # Get final image transfomer -- 3-D array + temp = torch.matmul(reshaped_vqstar.permute(2, 3, 0, 1), + inv_B).reshape(grow, gcol, 2).permute(2, 0, 1) # [2, grow, gcol] + norm_temp = torch.norm(temp, dim=0, keepdim=True) # [1, grow, gcol] + norm_vqstar = torch.norm(vqstar, dim=0, keepdim=True) # [1, grow, gcol] + transformers = temp / (norm_temp + 1e-10) * norm_vqstar + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == torch.tensor(float("inf")) # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + + 
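+ # transformers holds the inverse-mapped (row, col) coordinate of every grid point; casting to
+ # integer indices below looks the pixels up directly in the original image (nearest-neighbour style).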
transformed_image = image[tuple(transformers.numpy().astype(np.int16))] # [grow, gcol] + + # # Rescale image + # img_h, img_w, _ = transformed_image.shape + # transformed_image = transformed_image.resize_(int(img_h/density), int(img_w/density), channel) + + return transformers.numpy(), transformed_image + + + + + + + + +# mls for whole feature, instead of roi align +def roi_mls_whole(feature, d_point, g_point, step=1): + ''' + :param feature: itorchut guidance feature [C, H, W] + :param d_point: landmark for degraded feature [N, 2] + :param g_point: landmark for guidance feature [N, 2] + :param step: step of landmark choose, number of control points: landmark_number/step + :return: transformed feature [C, H, W] + ''' + # feature 3 * 256 * 256 + + channel = feature.size(0) + height = feature.size(1) # 256 * 256 + width = feature.size(2) + + # ignore the boarder point of face + g_land = g_point[0::step, :] + d_land = d_point[0::step, :] + + # mls + + featureTmp = feature.permute(1,2,0) + # grid, timg = mls_rigid_deformation_inv_wy(featureTmp.cpu(), height, width, channel, g_land.cpu(), d_land.cpu(), density=1.) # 2 * 256 * 256 # wenyu + grid, timg = mls_affine_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #affine prefered + # grid, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) # similarity + # grid, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #rigid + + + grid = (grid - 127.5) / 127.5 + gridNew = torch.from_numpy(grid[[1,0],:,:]).float().permute(1,2,0).unsqueeze(0) + + if torch.cuda.is_available(): + gridNew = gridNew.cuda() + featureNew = feature.unsqueeze(0) + # warp_feature = F.grid_sample(featureNew,gridNew.cuda(),mode='nearest') + warp_feature = F.grid_sample(featureNew,gridNew.cuda()) + + # HWC -> CHW + + + # _, timg_affine = mls_affine_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) + # _, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) + # _, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) + + # return warp_feature.squeeze(), timg.permute(2,0,1) #, timg_sim.permute(2,0,1), timg_rigid.permute(2,0,1) + return warp_feature.squeeze(), gridNew.cuda() + + +def roi_mls_whole_final(feature, d_point, g_point, step=1): + ''' + :param feature: itorchut guidance feature [C, H, W] + :param d_point: landmark for degraded feature [N, 2] + :param g_point: landmark for guidance feature [N, 2] + :param step: step of landmark choose, number of control points: landmark_number/step + :return: transformed feature [C, H, W] + ''' + # feature 3 * 256 * 256 + + channel = feature.size(0) + height = feature.size(1) # 256 * 256 + width = feature.size(2) + + # ignore the boarder point of face + g_land = g_point[0::step, :] + d_land = d_point[0::step, :] + + # mls + + # featureTmp = feature.permute(1,2,0) + # grid, timg = mls_rigid_deformation_inv_wy(featureTmp.cpu(), height, width, channel, g_land.cpu(), d_land.cpu(), density=1.) # 2 * 256 * 256 # wenyu + grid = mls_affine_deformation_inv_final(height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) 
#affine prefered + # grid, timg_sim = mls_similarity_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) # similarity + # grid, timg_rigid = mls_rigid_deformation_inv(featureTmp.cpu(), height, width, channel, g_land.cpu().numpy(), d_land.cpu().numpy(), density=1.) #rigid + + grid = (grid - height/2) / (height/2) + + gridNew = torch.from_numpy(grid[[1,0],:,:]).float().permute(1,2,0).unsqueeze(0) + + if torch.cuda.is_available(): + gridNew = gridNew.cuda() + # HWC -> CHW + return gridNew \ No newline at end of file diff --git a/util/get_data.py b/util/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6325605bc68ec3b4036a4e0f42c28e0b8965867d --- /dev/null +++ b/util/get_data.py @@ -0,0 +1,115 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """ + + Download CycleGAN or Pix2Pix Data. + + Args: + technique : str + One of: 'cyclegan' or 'pix2pix'. + verbose : bool + If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Args: + save_path : str + A directory to save the data to. + dataset : str, optional + A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full : str + The absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. 
Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/util/html.py b/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..c7956f1353fd25aee253e39a6178481b0b330621 --- /dev/null +++ b/util/html.py @@ -0,0 +1,64 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..52413e0f8a45a8c8511bf103d3aabd537fac97b9 --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import torch + + +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = torch.cat(return_images, 0) + return return_images diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..7206cea2d9188bc292a5f640096c0a01ec682b42 --- /dev/null +++ b/util/util.py @@ -0,0 +1,140 @@ +# -- coding: utf-8 -- +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os +import torchvision + +# Converts a Tensor into an image array (numpy) +# |imtype|: the desired type of the converted numpy array +def tensor2im(input_image, norm=1, imtype=np.uint8): + if isinstance(input_image, torch.Tensor): + 
image_tensor = input_image.data + else: + return input_image + if norm == 1: #for clamp -1 to 1 + image_numpy = image_tensor[0].cpu().float().clamp_(-1,1).numpy() + elif norm == 2: # for norm through max-min + image_ = image_tensor[0].cpu().float() + max_ = torch.max(image_) + min_ = torch.min(image_) + image_numpy = (image_ - min_)/(max_-min_)*2-1 + image_numpy = image_numpy.numpy() + else: + pass + if image_numpy.shape[0] == 1: + image_numpy = np.tile(image_numpy, (3, 1, 1)) + # print(image_numpy.shape) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + # print(image_numpy.shape) + return image_numpy.astype(imtype) +def tensor2im3Channels(input_image, imtype=np.uint8): + if isinstance(input_image, torch.Tensor): + image_tensor = input_image.data + else: + return input_image + + image_numpy = image_tensor.cpu().float().clamp_(-1,1).numpy() + + # print(image_numpy.shape) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + # print(image_numpy.shape) + return image_numpy.astype(imtype) + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + + + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def print_current_losses(epoch, i, losses, t, t_data): + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + # with open('', "a") as log_file: + # log_file.write('%s\n' % message) + +def display_current_results(writer,visuals,losses,step,save_result): + for label, images in visuals.items(): + if 'Mask' in label:# or 'Scale' in label: + grid = torchvision.utils.make_grid(images,normalize=False, scale_each=True) + # pass + else: + pass + grid = torchvision.utils.make_grid(images,normalize=True, scale_each=True) + writer.add_image(label,grid,step) + for k,v in losses.items(): + writer.add_scalar(k,v,step) + +def VisualFeature(input_feature, imtype=np.uint8): + if isinstance(input_feature, torch.Tensor): + image_tensor = input_feature.data + else: + return input_feature + + image_ = image_tensor.cpu().float() + + if image_.size(1) == 3: + image_ = image_.permute(1,2,0) + + # assert(image_.size(1) == 1) + + + + #####norm 0 to 1 + max_ = torch.max(image_) + min_ = torch.min(image_) + image_numpy = (image_ - min_)/(max_-min_)*2-1 + image_numpy = image_numpy.numpy() + image_numpy = (image_numpy + 1) / 2.0 * 255.0 + #####no norm + # print((max_,min_)) + # image_numpy = image_.numpy() + # image_numpy = image_numpy*255.0 + + + # print('wwwwwwwwwwwwww') + # print(max_) + # print(min_) + # print(image_numpy.shape) + return image_numpy.astype(imtype) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..77962a8750eb943e2d5c3d0052a02586d7115e08 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,242 @@ +import numpy as np +import os +import ntpath +import time +from . import util +from . import html +# from scipy.misc import imresize + + +# save image to the disk +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + image_dir = webpage.get_image_dir() + if isinstance(image_path,list): + image_path = image_path[0] + short_path = ntpath.basename(image_path) + + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + # image_name = '%s_%s.png' % (name, label) + # save_path = os.path.join(image_dir, image_name) + image_name = '%s.png' % (name) + + util.mkdirs(os.path.join(image_dir, label)) + save_path = os.path.join(image_dir, label, image_name) + + h, w, _ = im.shape +# if aspect_ratio > 1.0: +# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') +# if aspect_ratio < 1.0: +# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + + util.save_image(im, save_path) + + link_name = os.path.join(label,image_name) + ims.append(link_name) + txts.append(label) + links.append(link_name) + webpage.add_images(ims, txts, links, width=width) + +def save_crop(visuals, save_path): + im_data = visuals['fake_A'] + im = util.tensor2im(im_data) + util.save_image(im, save_path) + +# save image to the disk +def save_images_test(webpage, visuals, image_dir, image_path, aspect_ratio=1.0, width=256): + # image_dir = webpage.get_image_dir() + if isinstance(image_path,list): + image_path = image_path[0] + short_path = ntpath.basename(image_path) + + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + # image_name = '%s_%s.png' % (name, label) + # save_path = os.path.join(image_dir, image_name) + image_name = '%s.png' % (name) + + util.mkdirs(os.path.join(image_dir, label)) + save_path = os.path.join(image_dir, label, image_name) + + exit('wwwwwwwwwwwww') + + + h, w, _ = im.shape +# if aspect_ratio > 1.0: +# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') +# if aspect_ratio < 1.0: +# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + + util.save_image(im, save_path) + + link_name = os.path.join(label,image_name) + ims.append(link_name) + txts.append(label) + links.append(link_name) + webpage.add_images(ims, txts, links, width=width) + +# save image to the disk +# Original Version +# def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): +# image_dir = webpage.get_image_dir() +# short_path = ntpath.basename(image_path[0]) +# name = os.path.splitext(short_path)[0] + +# webpage.add_header(name) +# ims, txts, links = [], [], [] + +# for label, im_data in visuals.items(): +# im = util.tensor2im(im_data) +# image_name = '%s_%s.png' % (name, label) +# save_path = os.path.join(image_dir, image_name) +# h, w, _ = im.shape +# if aspect_ratio > 1.0: +# im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') +# if aspect_ratio < 1.0: +# im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') +# util.save_image(im, save_path) + +# ims.append(image_name) +# txts.append(label) +# links.append(image_name) +# webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + def __init__(self, opt): + self.display_id = opt.display_id 
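+ # display_id > 0 enables the visdom display configured below; when it is non-positive,
+ # results are only written to disk through the HTML logger (see use_html and the checks
+ # in display_current_results).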
+ self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.opt = opt + self.saved = False + if self.display_id > 0: + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + self.saved = False + + def throw_visdom_connection_error(self): + print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') + exit(1) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, save_result): + if self.display_id > 0: # show images in the browser + ncols = self.ncols + if ncols > 0: + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) + title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + # pane col = image row + try: + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except ConnectionError: + self.throw_visdom_connection_error() + + else: + idx = 1 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + + if self.use_html and (save_result or not self.saved): # save images to a html file + self.saved = True + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + # losses: dictionary of error labels and values + def plot_current_losses(self, epoch, counter_ratio, opt, losses): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except ConnectionError: + self.throw_visdom_connection_error() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, i, losses, t, t_data): + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message)
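Usage note (not part of the diff above): a minimal sketch of how the MLS utilities in util/ROI_MLS.py might be wired to torch.nn.functional.grid_sample to align a guidance feature map with a degraded face. The tensor shapes, the 68-point landmarks, the random inputs, and the availability of a CUDA device are assumptions made purely for illustration; roi_mls_whole_final itself only moves the returned grid to the GPU when CUDA is available.

import torch
import torch.nn.functional as F
from util.ROI_MLS import roi_mls_whole_final

# hypothetical inputs: a [C, H, W] guidance feature and 68 (x, y) landmarks per face
feature = torch.randn(64, 256, 256).cuda()
d_land = (torch.rand(68, 2) * 255).cuda()   # landmarks of the degraded face
g_land = (torch.rand(68, 2) * 255).cuda()   # landmarks of the guidance face

grid = roi_mls_whole_final(feature, d_land, g_land, step=1)   # [1, H, W, 2], values roughly in [-1, 1]
warped = F.grid_sample(feature.unsqueeze(0), grid)            # [1, C, H, W] guidance aligned to d_land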