Spaces:
Running
Running
| # Prediction interface for Cog ⚙️ | |
| # https://github.com/replicate/cog/blob/main/docs/python.md | |
| import shutil | |
| import gradio as gr | |
| # from cog import BasePredictor, Input, Path | |
| import insightface | |
| import onnxruntime | |
| from insightface.app import FaceAnalysis | |
| import cv2 | |
| import gfpgan | |
| import tempfile | |
| import time | |
| import uuid | |
| from typing import Any, Union | |
| from loggers import logger, request_id as _request_id | |
| import ssl | |
| from datetime import datetime | |
| import traceback | |
| import torch | |
| import os | |
| import requests | |
| import subprocess | |
| import sys | |
| from PIL import Image | |
| import numpy as np | |
# Replace ssl's default HTTPS context factory with the unverified one.
# NOTE(review): this disables certificate verification for code that relies
# on the default HTTPS context -- confirm it is still required.
ssl._create_default_https_context = ssl._create_unverified_context

# Directory holding cached downloads: '/tmp/file' on macOS, '/src/file' elsewhere.
cache_file_dir = '/tmp/file' if sys.platform == 'darwin' else '/src/file'
| # os.makedirs(cache_file_dir, exist_ok=True) | |
def img_url_to_local_path(img_url, file_path=None):
    """Download ``img_url`` to a local file and return that file's path.

    Args:
        img_url: URL to fetch.
        file_path: Destination path. When ``None`` a persistent temporary
            file is created whose name ends with the last path segment of
            the URL.

    Returns:
        The path the content was written to.

    Raises:
        Exception: when all 3 attempts fail; the last underlying error is
            attached as ``__cause__``.
    """
    filename = img_url.split('/')[-1]
    max_count = 3
    if file_path is None:
        # Only the name is needed; close the handle immediately so the
        # descriptor is not leaked (delete=False keeps the file on disk).
        with tempfile.NamedTemporaryFile(delete=False, suffix=filename) as temp_file:
            temp_file_name = temp_file.name
    else:
        temp_file_name = file_path

    for attempt in range(1, max_count + 1):
        try:
            res = requests.get(img_url, timeout=60)
            res.raise_for_status()
            with open(temp_file_name, "wb") as f:
                f.write(res.content)
            return temp_file_name
        except Exception as e:
            # Deliberately broad: any failure (network, HTTP status, disk
            # write) is logged and retried until the attempts run out.
            logger.error(e)
            if attempt >= max_count:
                msg = f'request {max_count} time url: {img_url} failed, please check'
                logger.error(msg)
                raise Exception(msg) from e
def delete_files_day_ago(cache_days=10):
    """Remove regular files under ``cache_file_dir`` selected by ``find -ctime +cache_days``.

    The command is passed as an argument list without a shell, so no
    escaping of ``;`` or of the directory name is needed. Cleanup is
    best-effort: failures are logged and never raised.

    Args:
        cache_days: Value used for find's ``-ctime +N`` test.
    """
    command = [
        "find", cache_file_dir,
        "-type", "f",
        "-ctime", f"+{cache_days}",
        "-exec", "rm", "{}", ";",
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as e:
        # e.g. the `find` executable is not available on this platform.
        logger.error(e)
        return
    logger.info(result.stdout)
    if result.returncode != 0:
        # Surface problems such as a missing cache directory.
        logger.error(result.stderr)
def image_format_by_path(image_path):
    """Return a lowercase file extension for the image stored at ``image_path``.

    The format reported by PIL is lowercased, except that ``"JPEG"`` and a
    missing format both map to ``'jpg'``.

    Args:
        image_path: Path of the image file to inspect.

    Returns:
        The extension string without a leading dot, e.g. ``'jpg'`` or ``'png'``.
    """
    # Use a context manager so the underlying file handle is released;
    # only the format attribute is needed.
    with Image.open(image_path) as image:
        image_format = image.format
    if not image_format or image_format == "JPEG":
        return 'jpg'
    return image_format.lower()
def local_file_for_url(url, cache_days=10):
    """Return the cached local path for ``url``, downloading it on a cache miss.

    The cache key is the last path segment of the URL, stored under
    ``cache_file_dir``. After a fresh download, old cache files are pruned
    via ``delete_files_day_ago(cache_days)``.

    Args:
        url: URL of the file to fetch.
        cache_days: Age threshold forwarded to ``delete_files_day_ago``.

    Returns:
        Path of the local file.
    """
    # The file name is used as-is; it is no longer split on '.', which used
    # to raise ValueError for names without exactly one dot.
    filename = url.split('/')[-1]
    file_path = f'{cache_file_dir}/{filename}'

    if os.path.exists(file_path):
        logger.info(f'cache file {file_path}')
        return file_path

    img_url_to_local_path(url, file_path)
    logger.info(f'download file to {file_path}')
    delete_files_day_ago(cache_days)
    return file_path
class Predictor:
    """Face-swap pipeline: detect faces, swap them and optionally enhance the result.

    ``setup()`` must be called once to load the models before ``predict()``.
    """

    def __init__(self):
        # Detection threshold applied when searching the *source* image.
        self.det_thresh = 0.1

    def setup(self):
        """Load the swapper, enhancer and face-analysis models."""
        self.face_swapper = insightface.model_zoo.get_model(
            'cache/inswapper_128.onnx',
            providers=onnxruntime.get_available_providers(),
        )
        self.face_enhancer = gfpgan.GFPGANer(model_path='cache/GFPGANv1.4.pth', upscale=1)
        self.face_analyser = FaceAnalysis(name='buffalo_l')

    def get_face(self, img_data, image_type='target'):
        """Return the detected face with the largest bounding-box area.

        Args:
            img_data: Image passed to the face analyser.
            image_type: ``'source'`` uses ``self.det_thresh`` as detection
                threshold; any other value uses 0.5.

        Raises:
            Exception: when no face is found or detection fails; the message
                is the text of the underlying error.
        """
        try:
            logger.info(self.det_thresh)
            # Prepare the analyser once with the threshold that applies,
            # instead of preparing with 0.5 and then again for source images.
            det_thresh = self.det_thresh if image_type == 'source' else 0.5
            self.face_analyser.prepare(ctx_id=0, det_thresh=det_thresh)
            analysed = self.face_analyser.get(img_data)
            logger.info(f'face num: {len(analysed)}')
            if len(analysed) == 0:
                msg = 'no face'
                logger.error(msg)
                raise Exception(msg)
            return max(
                analysed,
                key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]),
            )
        except Exception as e:
            logger.error(str(e))
            raise Exception(str(e)) from e

    def enhance_face(self, target_face, target_frame, weight=0.3):
        """Enhance the region around ``target_face`` inside ``target_frame`` in place.

        The face bounding box is grown by 50% of its width/height on every
        side, clamped to the frame, run through the enhancer and written back.

        Args:
            target_face: Mapping with a ``'bbox'`` entry ``(x1, y1, x2, y2)``.
            target_frame: Image array indexed as ``[y, x]``.
            weight: Passed through to the enhancer.

        Returns:
            ``target_frame`` (the same object, modified when the region is
            non-empty).
        """
        frame_height, frame_width = target_frame.shape[:2]
        start_x, start_y, end_x, end_y = map(int, target_face['bbox'])
        padding_x = int((end_x - start_x) * 0.5)
        padding_y = int((end_y - start_y) * 0.5)
        start_x = max(0, start_x - padding_x)
        start_y = max(0, start_y - padding_y)
        # Clamp the far edges to the frame size as well as to zero.
        end_x = min(frame_width, max(0, end_x + padding_x))
        end_y = min(frame_height, max(0, end_y + padding_y))
        temp_face = target_frame[start_y:end_y, start_x:end_x]
        if temp_face.size:
            _, _, temp_face = self.face_enhancer.enhance(
                temp_face,
                paste_back=True,
                weight=weight
            )
            target_frame[start_y:end_y, start_x:end_x] = temp_face
        return target_frame

    def predict(
            self,
            source_image_path,
            target_image_path,
            enhance_face,
    ) -> Any:
        """Swap the largest face of the source image onto the target image.

        The detection threshold (0.1) and the enhancement weight (0.5) are
        fixed inside this method.

        Args:
            source_image_path: Path of the image providing the face.
            target_image_path: Path of the image receiving the face.
            enhance_face: When truthy, run the enhancer over the swapped face.

        Returns:
            On success, a PIL image opened from the written result file.
            On any failure, a dict with keys ``code`` (500), ``msg``,
            ``image`` (empty string) and ``status`` (``'failed'``).
        """
        det_thresh = 0.1
        weight = 0.5
        # Informational only: logged, never used to place models.
        # NOTE(review): reports 'mps' whenever CUDA is unavailable.
        device = 'cuda' if torch.cuda.is_available() else 'mps'
        logger.info(f'device: {device}, det_thresh:{det_thresh}')
        try:
            self.det_thresh = det_thresh
            start_time = time.time()
            request_id = str(uuid.uuid4())
            _request_id.set(request_id)

            frame = cv2.imread(str(target_image_path))
            if frame is None:
                # cv2.imread signals failure by returning None.
                raise Exception(f'cannot read target image: {target_image_path}')
            source_frame = cv2.imread(str(source_image_path))
            if source_frame is None:
                raise Exception(f'cannot read source image: {source_image_path}')

            source_face = self.get_face(source_frame, image_type='source')
            target_face = self.get_face(frame)
            try:
                logger.info(f'{frame.shape}, {target_face.shape}, {source_face.shape}')
            except Exception as e:
                logger.error(f"printing shapes failed, error:{str(e)}")
                raise Exception(str(e)) from e

            ext = image_format_by_path(target_image_path)
            size = os.path.getsize(target_image_path)
            logger.info(f'origin {size/1024}k')

            result = self.face_swapper.get(frame, target_face, source_face, paste_back=True)
            if enhance_face:
                result = self.enhance_face(target_face, result, weight)

            out_path = f"{tempfile.mkdtemp()}/{uuid.uuid4()}.{ext}"
            if not cv2.imwrite(str(out_path), result):
                raise Exception(f'cannot write result image: {out_path}')

            # Log size and timing before returning; these statements used to
            # sit after the return and never ran.
            size = os.path.getsize(out_path)
            logger.info(f'result {size / 1024}k')
            cost_time = time.time() - start_time
            logger.info(f'total time: {cost_time * 1000} ms')
            return Image.open(out_path)
        except Exception as e:
            logger.error(traceback.format_exc())
            data = {'code': 500, 'msg': str(e), 'image': '', 'status': 'failed'}
            logger.error(f"{str(e)}")
            return data
import threading

# Lazily created Predictor shared by all requests, so the models are loaded
# once instead of on every call.
_PREDICTOR = None
# Serialises creation and use of the shared predictor: get_face()
# re-prepares the shared analyser on every call, so concurrent requests
# must not interleave.
_PREDICTOR_LOCK = threading.Lock()


def _get_predictor():
    """Return the shared Predictor, loading its models on first use.

    Must be called with ``_PREDICTOR_LOCK`` held.
    """
    global _PREDICTOR
    if _PREDICTOR is None:
        predictor = Predictor()
        predictor.setup()
        # Publish only after setup() succeeded, so a failed load is retried
        # on the next request.
        _PREDICTOR = predictor
    return _PREDICTOR


def swap_faces(source_image_path, target_image_path, enhance_face):
    """Gradio callback: swap the source face onto the target image.

    Args:
        source_image_path: Path of the image providing the face.
        target_image_path: Path of the image receiving the face.
        enhance_face: Whether to run face enhancement on the result.

    Returns:
        Whatever ``Predictor.predict`` returns for these arguments.
    """
    with _PREDICTOR_LOCK:
        return _get_predictor().predict(
            source_image_path,
            target_image_path,
            enhance_face,
        )
if __name__ == "__main__":
    # Build the UI components first, then wire them into the interface.
    source_input = gr.Image(type="filepath")
    target_input = gr.Image(type="filepath")
    enhance_toggle = gr.Checkbox(label="Enhance Face", value=True)
    result_output = gr.Image(type="pil", show_download_button=True)

    demo = gr.Interface(
        fn=swap_faces,
        inputs=[source_input, target_input, enhance_toggle],
        outputs=[result_output],
        title="Swap Faces",
        allow_flagging="never",
    )
    demo.launch()