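Image reconstruction with Open-Sora-Plan's CausalVAE: the script below encodes a single image (treated as a one-frame video) into the VAE's latent space, decodes a sample of that latent, and writes the reconstruction to disk.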
import sys
sys.path.append(".")

import argparse

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor, Compose, Resize, Normalize

from opensora.models.ae.videobase import CausalVAEModel
def preprocess(video_data: Image.Image, short_size: int = 128) -> torch.Tensor:
    # Map a PIL image to a tensor in [-1, 1], resize its short side to
    # `short_size`, then add batch and time dimensions so the result has
    # the 5D video layout [B, C, T, H, W] expected by the causal VAE.
    transform = Compose(
        [
            ToTensor(),
            Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1]
            Resize(size=short_size),
        ]
    )
    outputs = transform(video_data)
    outputs = outputs.unsqueeze(0).unsqueeze(2)  # [C, H, W] -> [1, C, 1, H, W]
    return outputs
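
# Shape note (illustrative): with a 256x256 RGB input and short_size=128,
# preprocess() returns a tensor of shape [1, 3, 1, 128, 128] -- batch,
# channel, time, height, width -- i.e. the image is treated as a
# single-frame video.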
def main(args: argparse.Namespace):
    image_path = args.image_path
    resolution = args.resolution
    device = args.device

    # Load the pretrained CausalVAE checkpoint and put it in eval mode.
    vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
    vqvae.eval()
    vqvae = vqvae.to(device)

    with torch.no_grad():
        # Encode the image to the latent posterior, then decode a sample
        # of that posterior back to pixel space.
        x_vae = preprocess(Image.open(image_path), resolution)
        x_vae = x_vae.to(device)
        latents = vqvae.encode(x_vae)
        recon = vqvae.decode(latents.sample())

    # Drop the batch and time dimensions: [1, C, 1, H, W] -> [C, H, W].
    x = recon[0, :, 0, :, :]
    x = x.detach().cpu().numpy()
    x = np.clip(x, -1, 1)            # clamp to the VAE's output range
    x = (x + 1) / 2                  # [-1, 1] -> [0, 1]
    x = (255 * x).astype(np.uint8)   # [0, 1] -> [0, 255]
    x = x.transpose(1, 2, 0)         # CHW -> HWC for PIL
    image = Image.fromarray(x)
    image.save(args.rec_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', type=str, default='')
    parser.add_argument('--rec-path', type=str, default='')
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
    main(args)
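
Assuming the script is saved as rec_image.py (the filename and paths here are illustrative) and a CausalVAE checkpoint is available locally, a typical invocation looks like:

python rec_image.py \
    --image-path input.jpg \
    --rec-path recon.jpg \
    --ckpt /path/to/causalvae_checkpoint \
    --resolution 336 \
    --device cuda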