Commit: update

Files changed:
- CodeFormer/.gitignore → .gitignore  +5 -4
- CodeFormer/basicsr/utils/misc.py  +25 -2
- CodeFormer/basicsr/version.py  +2 -2
- CodeFormer/facelib/utils/face_restoration_helper.py  +77 -12
- CodeFormer/facelib/utils/misc.py  +32 -4
- CodeFormer/inference_codeformer.py  +126 -41
- README.md  +1 -1
- app.py  +12 -10
CodeFormer/.gitignore → .gitignore
RENAMED

@@ -5,9 +5,9 @@ version.py
 
 # ignored files with suffix
 *.html
-
-
-
+*.png
+*.jpeg
+*.jpg
 *.pt
 *.gif
 *.pth

@@ -122,7 +122,8 @@ venv.bak/
 .mypy_cache/
 
 # project
-results/
+CodeFormer/results/
+output/
 dlib/
 *.pth
 *_old*
CodeFormer/basicsr/utils/misc.py
CHANGED

@@ -1,13 +1,36 @@
-import numpy as np
 import os
+import re
 import random
 import time
 import torch
+import numpy as np
 from os import path as osp
 
 from .dist_util import master_only
 from .logger import get_root_logger
 
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
+    torch.__version__)[0][:3])] >= [1, 12, 0]
+
+def gpu_is_available():
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return True
+    return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+def get_device(gpu_id=None):
+    if gpu_id is None:
+        gpu_str = ''
+    elif isinstance(gpu_id, int):
+        gpu_str = f':{gpu_id}'
+    else:
+        raise TypeError('Input should be int value.')
+
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return torch.device('mps'+gpu_str)
+    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
 
 def set_random_seed(seed):
     """Set random seeds."""

@@ -131,4 +154,4 @@ def sizeof_fmt(size, suffix='B'):
         if abs(size) < 1024.0:
             return f'{size:3.1f} {unit}{suffix}'
         size /= 1024.0
-    return f'{size:3.1f} Y{suffix}'
+    return f'{size:3.1f} Y{suffix}'
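Note: the new `get_device()` / `gpu_is_available()` helpers extend device selection to Apple-silicon MPS on PyTorch >= 1.12. A minimal usage sketch (not part of the commit; it only assumes the two helpers above are importable from `basicsr.utils.misc`):

    import torch
    from basicsr.utils.misc import get_device, gpu_is_available

    device = get_device()        # torch.device('mps'), 'cuda' or 'cpu', depending on the machine
    print(gpu_is_available())    # True when CUDA (with cuDNN) or MPS is usable
    x = torch.zeros(1, 3, 512, 512, device=device)  # tensors/models can now be placed portably

Passing an integer selects a specific CUDA card, e.g. `get_device(0)` returns `cuda:0` on a CUDA machine; any other type raises `TypeError`, as the diff shows.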
CodeFormer/basicsr/version.py
CHANGED

@@ -1,5 +1,5 @@
 # GENERATED VERSION FILE
-# TIME: 
+# TIME: Sat Sep 21 15:31:46 2024
 __version__ = '1.3.2'
-__gitsha__ = '
+__gitsha__ = '1.3.2'
 version_info = (1, 3, 2)
CodeFormer/facelib/utils/face_restoration_helper.py
CHANGED

@@ -6,8 +6,14 @@ from torchvision.transforms.functional import normalize
 
 from facelib.detection import init_detection_model
 from facelib.parsing import init_parsing_model
-from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import get_device
 
+dlib_model_url = {
+    'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
+    'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
+}
 
 def get_largest_face(det_faces, h, w):

@@ -64,8 +70,15 @@ class FaceRestoreHelper(object):
         self.crop_ratio = crop_ratio  # (h, w)
         assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
         self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
-
-
+        self.det_model = det_model
+
+        if self.det_model == 'dlib':
+            # standard 5 landmarks for FFHQ faces with 1024 x 1024
+            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
+                                           [513.58415842, 678.5049505]])
+            self.face_template = self.face_template / (1024 // face_size)
+        elif self.template_3points:
             self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
         else:
             # standard 5 landmarks for FFHQ faces with 512 x 512

@@ -77,7 +90,6 @@ class FaceRestoreHelper(object):
             # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
             #                                [198.22603, 372.82502], [313.91018, 372.75659]])
 
-
         self.face_template = self.face_template * (face_size / 512.0)
         if self.crop_ratio[0] > 1:
             self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2

@@ -97,12 +109,16 @@ class FaceRestoreHelper(object):
         self.pad_input_imgs = []
 
         if device is None:
-            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            self.device = get_device()
         else:
             self.device = device
 
         # init face detection model
-        self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+        if self.det_model == 'dlib':
+            self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
+        else:
+            self.face_detector = init_detection_model(det_model, half=False, device=self.device)
 
         # init face parsing model
         self.use_parse = use_parse

@@ -125,7 +141,7 @@ class FaceRestoreHelper(object):
             img = img[:, :, 0:3]
 
         self.input_img = img
-        self.is_gray = is_gray(img, threshold=
+        self.is_gray = is_gray(img, threshold=10)
         if self.is_gray:
             print('Grayscale input: True')

@@ -133,25 +149,72 @@ class FaceRestoreHelper(object):
             f = 512.0/min(self.input_img.shape[:2])
             self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
 
+    def init_dlib(self, detection_path, landmark5_path):
+        """Initialize the dlib detectors and predictors."""
+        try:
+            import dlib
+        except ImportError:
+            print('Please install dlib by running:' 'conda install -c conda-forge dlib')
+        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
+        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
+        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+        return face_detector, shape_predictor_5
+
+    def get_face_landmarks_5_dlib(self,
+                                  only_keep_largest=False,
+                                  scale=1):
+        det_faces = self.face_detector(self.input_img, scale)
+
+        if len(det_faces) == 0:
+            print('No face detected. Try to increase upsample_num_times.')
+            return 0
+        else:
+            if only_keep_largest:
+                print('Detect several faces and only keep the largest.')
+                face_areas = []
+                for i in range(len(det_faces)):
+                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
+                    face_areas.append(face_area)
+                largest_idx = face_areas.index(max(face_areas))
+                self.det_faces = [det_faces[largest_idx]]
+            else:
+                self.det_faces = det_faces
+
+        if len(self.det_faces) == 0:
+            return 0
+
+        for face in self.det_faces:
+            shape = self.shape_predictor_5(self.input_img, face.rect)
+            landmark = np.array([[part.x, part.y] for part in shape.parts()])
+            self.all_landmarks_5.append(landmark)
+
+        return len(self.all_landmarks_5)
+
+
     def get_face_landmarks_5(self,
                              only_keep_largest=False,
                              only_center_face=False,
                              resize=None,
                              blur_ratio=0.01,
                              eye_dist_threshold=None):
+        if self.det_model == 'dlib':
+            return self.get_face_landmarks_5_dlib(only_keep_largest)
+
         if resize is None:
             scale = 1
             input_img = self.input_img
         else:
             h, w = self.input_img.shape[0:2]
             scale = resize / min(h, w)
-            scale = max(1, scale) # always scale up
+            # scale = max(1, scale) # always scale up; comment this out for HD images, e.g., AIGC faces.
             h, w = int(h * scale), int(w * scale)
             interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
             input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
 
         with torch.no_grad():
-            bboxes = self.
+            bboxes = self.face_detector.detect_faces(input_img)
 
         if bboxes is None or bboxes.shape[0] == 0:
             return 0

@@ -298,10 +361,12 @@ class FaceRestoreHelper(object):
             torch.save(inverse_affine, save_path)
 
 
-    def add_restored_face(self,
+    def add_restored_face(self, restored_face, input_face=None):
         if self.is_gray:
-
-
+            restored_face = bgr2gray(restored_face) # convert img into grayscale
+            if input_face is not None:
+                restored_face = adain_npy(restored_face, input_face) # transfer the color
+        self.restored_faces.append(restored_face)
 
 
     def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
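Note: `FaceRestoreHelper` can now run detection either with the existing retinaface/YOLOv5 models or with dlib (CNN detector plus 5-point shape predictor, downloaded from the URLs above). A hedged usage sketch follows; the keyword names mirror the attributes visible in this diff (`det_model`, `face_size`, `use_parse`), but check the class definition for the exact constructor signature:

    from facelib.utils.face_restoration_helper import FaceRestoreHelper

    # 'dlib' triggers init_dlib(), which fetches the two .dat files into weights/dlib on first use
    face_helper = FaceRestoreHelper(2, face_size=512, det_model='dlib', use_parse=True)
    face_helper.read_image('inputs/whole_imgs/00.jpg')            # path or BGR ndarray
    n = face_helper.get_face_landmarks_5(only_keep_largest=True)  # dispatches to get_face_landmarks_5_dlib()
    print(f'{n} face(s) detected')

With the default `det_model='retinaface_resnet50'` the call path is unchanged except that the bounding boxes now come from `self.face_detector.detect_faces(input_img)`.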
CodeFormer/facelib/utils/misc.py
CHANGED

@@ -7,13 +7,13 @@ import torch
 from torch.hub import download_url_to_file, get_dir
 from urllib.parse import urlparse
 # from basicsr.utils.download_util import download_file_from_google_drive
-# import gdown
-
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 def download_pretrained_models(file_ids, save_path_root):
+    import gdown
+
     os.makedirs(save_path_root, exist_ok=True)
 
     for file_name, file_id in file_ids.items():

@@ -23,7 +23,7 @@ def download_pretrained_models(file_ids, save_path_root):
             user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
             if user_response.lower() == 'y':
                 print(f'Covering {file_name} to {save_path}')
-
+                gdown.download(file_url, save_path, quiet=False)
                 # download_file_from_google_drive(file_id, save_path)
             elif user_response.lower() == 'n':
                 print(f'Skipping {file_name}')

@@ -31,7 +31,7 @@ def download_pretrained_models(file_ids, save_path_root):
                 raise ValueError('Wrong input. Only accepts Y/N.')
         else:
             print(f'Downloading {file_name} to {save_path}')
-
+            gdown.download(file_url, save_path, quiet=False)
             # download_file_from_google_drive(file_id, save_path)
 
 

@@ -172,3 +172,31 @@ def bgr2gray(img, out_channel=3):
     if out_channel == 3:
         gray = gray[:,:,np.newaxis].repeat(3, axis=2)
     return gray
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """
+    Args:
+        feat (numpy): 3D [w h c]s
+    """
+    size = feat.shape
+    assert len(size) == 3, 'The input feature should be 3D tensor.'
+    c = size[2]
+    feat_var = feat.reshape(-1, c).var(axis=0) + eps
+    feat_std = np.sqrt(feat_var).reshape(1, 1, c)
+    feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
+    return feat_mean, feat_std
+
+
+def adain_npy(content_feat, style_feat):
+    """Adaptive instance normalization for numpy.
+
+    Args:
+        content_feat (numpy): The input feature.
+        style_feat (numpy): The reference feature.
+    """
+    size = content_feat.shape
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
+    return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
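Note: `adain_npy` applies AdaIN-style statistics matching on numpy images: each channel of the content image is normalized and then rescaled to the reference image's per-channel mean/std. `add_restored_face` uses it to transfer color from the input crop back onto a grayscale-restored face. A small illustrative check (random arrays stand in for real faces):

    import numpy as np
    from facelib.utils.misc import adain_npy

    restored  = np.random.rand(512, 512, 3).astype(np.float32)               # stand-in for the restored face
    reference = np.random.rand(512, 512, 3).astype(np.float32) * 0.5 + 0.2   # stand-in for the input face
    recolored = adain_npy(restored, reference)
    # per-channel mean/std of `recolored` now approximately match `reference`
    print(recolored.reshape(-1, 3).mean(axis=0), reference.reshape(-1, 3).mean(axis=0))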
CodeFormer/inference_codeformer.py
CHANGED

@@ -1,4 +1,3 @@
-# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
 import os
 import cv2
 import argparse

@@ -7,8 +6,9 @@ import torch
 from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import gpu_is_available, get_device
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
-
+from facelib.utils.misc import is_gray
 
 from basicsr.utils.registry import ARCH_REGISTRY

@@ -17,51 +17,104 @@ pretrain_model_url = {
 }
 
 def set_realesrgan():
-
+    from basicsr.archs.rrdbnet_arch import RRDBNet
+    from basicsr.utils.realesrgan_utils import RealESRGANer
+
+    use_half = False
+    if torch.cuda.is_available(): # set False in CPU/MPS mode
+        no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
+        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
+            use_half = True
+
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=args.bg_tile,
+        tile_pad=40,
+        pre_pad=0,
+        half=use_half
+    )
+
+    if not gpu_is_available():  # CPU
         import warnings
-        warnings.warn('
-            '
+        warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
+                      'The unoptimized RealESRGAN is slow on CPU. '
+                      'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.',
                       category=RuntimeWarning)
-
-    else:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from basicsr.utils.realesrgan_utils import RealESRGANer
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
-        bg_upsampler = RealESRGANer(
-            scale=2,
-            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
-            model=model,
-            tile=args.bg_tile,
-            tile_pad=40,
-            pre_pad=0,
-            half=True) # need to set False in CPU mode
-        return bg_upsampler
+    return upsampler
 
 if __name__ == '__main__':
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = get_device()
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--
-
-    parser.add_argument('--
-
-    parser.add_argument('
+    parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
+            help='Input image, video or folder. Default: inputs/whole_imgs')
+    parser.add_argument('-o', '--output_path', type=str, default=None,
+            help='Output folder. Default: results/<input_name>_<w>')
+    parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
+            help='Balance the quality and fidelity. Default: 0.5')
+    parser.add_argument('-s', '--upscale', type=int, default=2,
+            help='The final upsampling scale of the image. Default: 2')
+    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
+    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
     # large det_model: 'YOLOv5l', 'retinaface_resnet50'
     # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
-    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50'
-
-
-    parser.add_argument('--
+    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
+            help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
+                Default: retinaface_resnet50')
+    parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
+    parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
     parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
+    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
+    parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
 
     args = parser.parse_args()
 
     # ------------------------ input & output ------------------------
-
-
+    w = args.fidelity_weight
+    input_video = False
+    if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
+        input_img_list = [args.input_path]
+        result_root = f'results/test_img_{w}'
+    elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+        from basicsr.utils.video_util import VideoReader, VideoWriter
+        input_img_list = []
+        vidreader = VideoReader(args.input_path)
+        image = vidreader.get_frame()
+        while image is not None:
+            input_img_list.append(image)
+            image = vidreader.get_frame()
+        audio = vidreader.get_audio()
+        fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
+        video_name = os.path.basename(args.input_path)[:-4]
+        result_root = f'results/{video_name}_{w}'
+        input_video = True
+        vidreader.close()
+    else: # input img folder
+        if args.input_path.endswith('/'):  # solve when path ends with /
+            args.input_path = args.input_path[:-1]
+        # scan all the jpg and png images
+        input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
+        result_root = f'results/{os.path.basename(args.input_path)}_{w}'
+
+    if not args.output_path is None: # set output path
+        result_root = args.output_path
 
-
-
+    test_img_num = len(input_img_list)
+    if test_img_num == 0:
+        raise FileNotFoundError('No input image/video is found...\n'
+                                '\tNote that --input_path for video should end with .mp4|.mov|.avi')
 
     # ------------------ set up background upsampler ------------------
     if args.bg_upsampler == 'realesrgan':

@@ -109,19 +162,27 @@ if __name__ == '__main__':
             device=device)
 
     # -------------------- start to processing ---------------------
-
-    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
+    for i, img_path in enumerate(input_img_list):
         # clean all the intermediate results to process the next image
         face_helper.clean_all()
 
-
-
-
-
+        if isinstance(img_path, str):
+            img_name = os.path.basename(img_path)
+            basename, ext = os.path.splitext(img_name)
+            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+        else: # for video processing
+            basename = str(i).zfill(6)
+            img_name = f'{video_name}_{basename}' if input_video else basename
+            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+            img = img_path
 
         if args.has_aligned:
             # the input faces are already cropped and aligned
             img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+            face_helper.is_gray = is_gray(img, threshold=10)
+            if face_helper.is_gray:
+                print('Grayscale input: True')
             face_helper.cropped_faces = [img]
         else:
             face_helper.read_image(img)

@@ -150,7 +211,7 @@ if __name__ == '__main__':
             restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
             restored_face = restored_face.astype('uint8')
-            face_helper.add_restored_face(restored_face)
+            face_helper.add_restored_face(restored_face, cropped_face)
 
         # paste_back
         if not args.has_aligned:

@@ -178,12 +239,36 @@ if __name__ == '__main__':
                 save_face_name = f'{basename}.png'
             else:
                 save_face_name = f'{basename}_{idx:02d}.png'
+            if args.suffix is not None:
+                save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
             save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
             imwrite(restored_face, save_restore_path)
 
         # save restored img
         if not args.has_aligned and restored_img is not None:
+            if args.suffix is not None:
+                basename = f'{basename}_{args.suffix}'
             save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
             imwrite(restored_img, save_restore_path)
 
-
+    # save enhanced video
+    if input_video:
+        print('Video Saving...')
+        # load images
+        video_frames = []
+        img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
+        for img_path in img_list:
+            img = cv2.imread(img_path)
+            video_frames.append(img)
+        # write images to video
+        height, width = video_frames[0].shape[:2]
+        if args.suffix is not None:
+            video_name = f'{video_name}_{args.suffix}.png'
+        save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
+        vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
+
+        for f in video_frames:
+            vidwriter.write_frame(f)
+        vidwriter.close()
+
+    print(f'\nAll results are saved in {result_root}')
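Note: the rewritten CLI accepts a single image, a video, or a folder through `-i/--input_path`, plus the new `-o`, `-w`, `-s`, `--suffix` and `--save_video_fps` options shown above, e.g. `python inference_codeformer.py -i inputs/whole_imgs -w 0.7 -s 2 --face_upsample`. The folder branch collects images with the case-insensitive glob `'*.[jpJP][pnPN]*[gG]'`; a quick illustrative check of what that pattern matches (glob uses the same matching rules as fnmatch):

    import fnmatch

    names = ['a.jpg', 'b.JPG', 'c.jpeg', 'd.png', 'e.PNG', 'f.gif']
    print([n for n in names if fnmatch.fnmatch(n, '*.[jpJP][pnPN]*[gG]')])
    # -> ['a.jpg', 'b.JPG', 'c.jpeg', 'd.png', 'e.PNG']  (gif is excluded)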
README.md
CHANGED

@@ -9,4 +9,4 @@ app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED

@@ -16,9 +16,9 @@ from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
-from facelib.utils.misc import is_gray
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.realesrgan_utils import RealESRGANer
+from facelib.utils.misc import is_gray
 
 from basicsr.utils.registry import ARCH_REGISTRY

@@ -166,9 +166,7 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
     # face restoration for each cropped face
     for idx, cropped_face in enumerate(face_helper.cropped_faces):
         # prepare data
-        cropped_face_t = img2tensor(
-            cropped_face / 255.0, bgr2rgb=True, float32=True
-        )
+        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
         normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
         cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

@@ -182,12 +180,10 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
             torch.cuda.empty_cache()
         except RuntimeError as error:
             print(f"Failed inference for CodeFormer: {error}")
-            restored_face = tensor2img(
-                cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
-            )
+            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
         restored_face = restored_face.astype("uint8")
-        face_helper.add_restored_face(restored_face)
+        face_helper.add_restored_face(restored_face, cropped_face)
 
         # paste_back
         if not has_aligned:

@@ -264,6 +260,12 @@ If you have any questions, please feel free to reach me out at <b>shangchenzhou@
     td {
         padding-right: 0px !important;
     }
+
+    .gradio-container-4-37-2 .prose table, .gradio-container-4-37-2 .prose tr, .gradio-container-4-37-2 .prose td, .gradio-container-4-37-2 .prose th {
+        border: 0px solid #ffffff;
+        border-bottom: 0px solid #ffffff;
+    }
+
     </style>
 
     <table>

@@ -302,5 +304,5 @@ demo = gr.Interface(
 )
 
 DEBUG = os.getenv('DEBUG') == '1'
-demo.launch(debug=DEBUG)
-
+# demo.launch(debug=DEBUG)
+demo.launch(debug=DEBUG, share=True)
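Note: in app.py the restored face is now passed back together with the original crop (`add_restored_face(restored_face, cropped_face)`), so grayscale inputs get the AdaIN color transfer described above. The `share=True` flag asks Gradio to open a temporary public link in addition to the local server; it mainly matters when app.py is run outside of Spaces, since the hosted Space already serves its own URL.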