Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import torch | |
| from PIL import Image, ImageFilter | |
| from transformers import CLIPTextModel | |
| from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel | |
# Command-line interface of the inference script.
#
# The namespace is parsed at module level because the code under the
# ``__main__`` guard reads ``args`` directly.
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
# Both validation arguments are handed straight to ``Image.open``, so they
# must name image files rather than directories.
parser.add_argument(
    "--validation_image",
    type=str,
    required=True,
    help="Path to the validation image file",
)
parser.add_argument(
    "--validation_mask",
    type=str,
    required=True,
    help="Path to the validation mask file",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./test-infer/",
    help="The output directory where predictions are saved",
)
# ``None`` (the default) leaves generation unseeded.
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
args = parser.parse_args()
if __name__ == "__main__":
    # Script entry point: load the inpainting pipeline with the fine-tuned
    # UNet and text encoder found under ``args.model_path``, generate a batch
    # of inpainted candidates for the validation image/mask pair, blend each
    # candidate back onto the source image and write the results as
    # ``<output_dir>/<index>.png``.
    os.makedirs(args.output_dir, exist_ok=True)

    # Prefer the GPU, but keep the script usable on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create & load model
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
    )
    pipe.unet = UNet2DConditionModel.from_pretrained(
        args.model_path,
        subfolder="unet",
        revision=None,
    )
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        args.model_path,
        subfolder="text_encoder",
        revision=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    # The generator must live on the same device as the pipeline.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    # ``Image.composite`` (used below) requires both images to share mode and
    # size and the mask to be mode "1", "L" or "RGBA"; normalise up front so
    # RGBA/palette sources or RGB masks do not fail at the blending step.
    # ``convert`` returns a fully loaded copy, so the files can be closed.
    with Image.open(args.validation_image) as source:
        image = source.convert("RGB")
    with Image.open(args.validation_mask) as source:
        mask_image = source.convert("L")
    if mask_image.size != image.size:
        mask_image = mask_image.resize(image.size)

    results = pipe(
        ["a photo of sks"] * 16,
        image=image,
        mask_image=mask_image,
        num_inference_steps=25,
        guidance_scale=5,
        generator=generator,
    ).images

    # MaxFilter grows the bright (white) part of the mask, i.e. the area that
    # ``Image.composite`` takes from the generated result; the box blur then
    # feathers the edge so the seam with the source image is less visible.
    dilate_kernel = ImageFilter.MaxFilter(3)
    mask_image = mask_image.filter(dilate_kernel)
    blur_kernel = ImageFilter.BoxBlur(1)
    mask_image = mask_image.filter(blur_kernel)

    for idx, result in enumerate(results):
        # The pipeline may return images at its own working resolution or
        # mode; align them with the source before compositing.
        if result.mode != image.mode:
            result = result.convert(image.mode)
        if result.size != image.size:
            result = result.resize(image.size)
        result = Image.composite(result, image, mask_image)
        result.save(os.path.join(args.output_dir, f"{idx}.png"))

    # Drop the pipeline and release cached GPU memory, if a GPU was used.
    del pipe
    if torch.cuda.is_available():
        torch.cuda.empty_cache()