Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import argparse | |
| import functools | |
| import os | |
| import pathlib | |
| import subprocess | |
| import sys | |
| import tarfile | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import PIL.Image | |
| import torch | |
# When deployed on Hugging Face Spaces, apply the repository's local `patch`
# file to the vendored gan-control checkout before importing from it.
# The exit status of `patch` is intentionally not checked (same as before).
if os.getenv('SYSTEM') == 'spaces':
    with open('patch') as f:
        subprocess.run(['patch', '-p1'], cwd='gan-control', stdin=f)

# Make the vendored gan-control sources importable; this must happen before
# the `gan_control` import below.
sys.path.insert(0, 'gan-control/src')
from gan_control.inference.controller import Controller
| TITLE = 'amazon-research/gan-control' | |
| DESCRIPTION = '''This is an unofficial demo for https://github.com/amazon-research/gan-control. | |
| Expected execution time on Hugging Face Spaces: 7s (for one image) | |
| ''' | |
| ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.gan-control" alt="visitor badge"/></center>' | |
| TOKEN = os.environ['TOKEN'] | |
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``device`` (default ``'cpu'``), ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue`` (True unless
        ``--disable-queue`` is given) and ``allow_flagging``
        (default ``'never'``).
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    add('--device', type=str, default='cpu')
    add('--theme', type=str)
    for switch in ('--live', '--share'):
        add(switch, action='store_true')
    add('--port', type=int)
    # Queueing is on by default; this flag turns it off.
    add('--disable-queue', dest='enable_queue', action='store_false')
    add('--allow-flagging', type=str, default='never')

    namespace = cli.parse_args()
    return namespace
def download_models() -> None:
    """Download and unpack the pretrained controller checkpoint if missing.

    Does nothing when the ``controller_age015id025exp02hai04ori02gam15``
    directory already exists in the current working directory. Otherwise
    fetches ``<that name>.tar.gz`` from the ``hysts/gan-control`` Hugging Face
    Hub repo (using ``TOKEN``) and extracts it into the current working
    directory; the archive presumably contains that directory — confirm if
    the checkpoint layout changes.
    """
    model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
    if model_dir.exists():
        return

    archive_path = huggingface_hub.hf_hub_download(
        'hysts/gan-control',
        f'{model_dir.name}.tar.gz',
        use_auth_token=TOKEN)
    with tarfile.open(archive_path) as archive:
        # The archive comes from the network: the 'data' extraction filter
        # rejects absolute paths, links pointing outside the target directory
        # and other unsafe members. Older Pythons without extraction filters
        # fall back to the previous unfiltered behaviour.
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(filter='data')
        else:
            archive.extractall()
def _render_controlled(
    controller: Controller,
    latent_w: torch.Tensor,
    ncols: int,
    **controls: torch.Tensor,
):
    """Re-generate ``latent_w`` with one set of controls and return a grid image."""
    image_tensors, _, _ = controller.gen_batch_by_controls(
        latent=latent_w, input_is_latent=True, **controls)
    return controller.make_resized_grid_image(image_tensors, nrow=ncols)


def run(
    seed: int,
    truncation: float,
    yaw: int,
    pitch: int,
    age: int,
    hair_color_r: float,
    hair_color_g: float,
    hair_color_b: float,
    nrows: int,
    ncols: int,
    controller: Controller,
    device: torch.device,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Generate a grid of faces and three controlled variants of it.

    Args:
        seed: RNG seed for the latent codes; clipped to the uint32 range.
        truncation: Truncation value passed to ``controller.gen_batch``.
        yaw, pitch: Head-pose control values (roll is fixed to 0).
        age: Age control value.
        hair_color_r, hair_color_g, hair_color_b: Hair colour in 0-255;
            scaled to [0, 1] and clamped.
        nrows, ncols: Grid layout; ``nrows * ncols`` images are generated.
            Coerced to ``int`` because UI widgets may deliver floats.
        controller: The gan-control inference controller.
        device: Device the initial latent batch is moved to.

    Returns:
        ``(original, pose_controlled, age_controlled, hair_controlled)`` grid
        images as produced by ``controller.make_resized_grid_image``
        (presumably PIL images — the Gradio outputs are declared
        ``type='pil'``; confirm against gan-control).
    """
    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
    # randn() and the grid's nrow need integers; a float here would raise.
    nrows = int(nrows)
    ncols = int(ncols)
    batch_size = nrows * ncols

    latent_size = controller.config.model_config['latent_size']
    latent = torch.from_numpy(
        np.random.RandomState(seed).randn(batch_size, latent_size)
    ).float().to(device)

    initial_image_tensors, _, initial_latent_w = controller.gen_batch(
        latent=latent, truncation=truncation)
    res0 = controller.make_resized_grid_image(initial_image_tensors,
                                              nrow=ncols)

    # Every variant starts from the same W latents, so only the controlled
    # attribute differs from res0.
    pose_control = torch.tensor([[yaw, pitch, 0]], dtype=torch.float32)
    res1 = _render_controlled(controller, initial_latent_w, ncols,
                              orientation=pose_control)

    age_control = torch.tensor([[age]], dtype=torch.float32)
    res2 = _render_controlled(controller, initial_latent_w, ncols,
                              age=age_control)

    hair_color = torch.tensor([[hair_color_r, hair_color_g, hair_color_b]],
                              dtype=torch.float32) / 255
    hair_color = torch.clamp(hair_color, 0, 1)
    res3 = _render_controlled(controller, initial_latent_w, ncols,
                              hair=hair_color)

    return res0, res1, res2, res3
def main():
    """Load the controller, wrap ``run`` in a Gradio Interface and serve it."""
    args = parse_args()
    device = torch.device(args.device)

    download_models()
    controller = Controller('controller_age015id025exp02hai04ori02gam15/',
                            device)

    # Bind the model and device so Gradio only supplies the UI inputs, and
    # copy run's metadata (name, docstring, ...) onto the partial.
    handler = functools.update_wrapper(
        functools.partial(run, controller=controller, device=device), run)

    # Order must match run's positional parameters.
    inputs = [
        gr.inputs.Number(default=0, label='Seed'),
        gr.inputs.Slider(0, 1, step=0.1, default=0.7, label='Truncation'),
        gr.inputs.Slider(-90, 90, step=1, default=30, label='Yaw'),
        gr.inputs.Slider(-90, 90, step=1, default=0, label='Pitch'),
        gr.inputs.Slider(15, 75, step=1, default=75, label='Age'),
        *[
            gr.inputs.Slider(0, 255, step=1, default=initial, label=caption)
            for caption, initial in (
                ('Hair Color (R)', 186),
                ('Hair Color (G)', 158),
                ('Hair Color (B)', 92),
            )
        ],
        gr.inputs.Slider(1, 3, step=1, default=1, label='Number of Rows'),
        gr.inputs.Slider(1, 5, step=1, default=5, label='Number of Columns'),
    ]
    # One output panel per image returned by run, in the same order.
    outputs = [
        gr.outputs.Image(type='pil', label=caption)
        for caption in (
            'Generated Image',
            'Head Pose Controlled',
            'Age Controlled',
            'Hair Color Controlled',
        )
    ]

    demo = gr.Interface(
        handler,
        inputs,
        outputs,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Script entry point: parse CLI options, fetch/load the model and launch the UI.
if __name__ == '__main__':
    main()