Spaces:
Build error
Build error
| from argparse import Namespace | |
| import os | |
| from os.path import join as pjoin | |
| from typing import Optional | |
| import cv2 | |
| import torch | |
| from tools import ( | |
| parse_face, | |
| match_histogram, | |
| ) | |
| from utils.torch_helpers import make_image | |
| from utils.misc import stem | |
def match_skin_histogram(
    imgs: torch.Tensor,
    sibling_img: torch.Tensor,
    spectral_sensitivity,
    im_sibling_dir: str,
    mask_dir: str,
    matched_hist_fn: Optional[str] = None,
    normalize=None,  # normalize the range of the tensor
):
    """Histogram-match the skin of the input image to that of a sibling image.

    The input and sibling images are saved to disk and face-parsed (hair
    excluded) to obtain masks. ``match_histogram`` is then run on the two
    saved images with those masks, producing a new version of the input
    whose masked-region histogram follows the sibling's.

    Args:
        imgs: Input image batch. Only the first image returned by
            ``make_image`` is used (see the batch-size TODO below).
        sibling_img: Reference (sibling) image batch; only its first image
            is used. Its last axis is reversed before saving (presumably
            RGB -> BGR for ``cv2.imwrite`` -- confirm).
        spectral_sensitivity: Forwarded unchanged to ``match_histogram``.
        im_sibling_dir: Directory that receives ``input.png`` and
            ``sibling.png``. It is resolved to an absolute path and created
            if missing.
        mask_dir: Output directory handed to ``parse_face``. ``match_histogram``
            is pointed at ``<mask_dir>/input.png`` and ``<mask_dir>/sibling.png``
            as its source and reference masks.
        matched_hist_fn: Output path passed to ``match_histogram``. Defaults
            to ``<im_sibling_dir>/match_histogram.png`` when empty/None.
        normalize: Optional callable applied to the returned tensor.

    Returns:
        A CPU ``torch.FloatTensor`` of shape ``(1, C, H, W)``, assuming
        ``match_histogram.main`` returns an ``H x W x C`` array. It holds that
        array divided by 255 (expected to lie in ``[0, 1]``) and is
        optionally passed through ``normalize``.
    """
    # TODO: Currently only allows imgs of batch size 1
    work_dir = os.path.abspath(im_sibling_dir)
    masks_root = os.path.abspath(mask_dir)
    input_file, sibling_file = "input.png", "sibling.png"

    first_input = make_image(imgs)[0]
    # NOTE(review): only the sibling has its channels reversed; presumably the
    # input is grayscale so channel order does not matter -- confirm.
    first_sibling = make_image(sibling_img)[0][..., ::-1]

    # Persist both images so the file-based tools below can consume them.
    os.makedirs(work_dir, exist_ok=True)
    input_path = pjoin(work_dir, input_file)
    sibling_path = pjoin(work_dir, sibling_file)
    cv2.imwrite(input_path, first_input)
    cv2.imwrite(sibling_path, first_sibling)

    # Face parsing of the images in work_dir, writing masks into masks_root.
    parse_face.main(
        Namespace(in_dir=work_dir, out_dir=masks_root, include_hair=False)
    )

    # Values preset on the namespace take precedence over argparse defaults
    # for options not given in ``args``; only the two image paths are parsed.
    preset = Namespace(
        out=matched_hist_fn or pjoin(work_dir, "match_histogram.png"),
        src_mask=pjoin(masks_root, input_file),
        ref_mask=pjoin(masks_root, sibling_file),
        spectral_sensitivity=spectral_sensitivity,
    )
    mh_args = match_histogram.parse_args(
        args=[input_path, sibling_path],
        namespace=preset,
    )

    matched_hwc = match_histogram.main(mh_args) / 255.0  # [0, 1]
    # HWC -> CHW, then add a leading batch axis -> BCHW.
    matched = torch.FloatTensor(matched_hwc).permute(2, 0, 1).unsqueeze(0)
    return matched if normalize is None else normalize(matched)