Spaces:
Sleeping
Sleeping
File size: 6,258 Bytes
40ac571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import cv2
import torch
from torch.nn import functional as F
import warnings
import os
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
class RIFE:
_instance = None
def __new__(cls, model_path='./ckpt_models/rife'):
if cls._instance is None:
cls._instance = super(RIFE, cls).__new__(cls)
cls._instance.initialize(model_path)
return cls._instance
def initialize(self, model_path):
try:
try:
from model.RIFE_HDv2 import Model
self.model = Model()
self.model.load_model(model_path, -1)
print("Loaded v2.x HD model.")
except:
from train_log.RIFE_HDv3 import Model
self.model = Model()
self.model.load_model(model_path, -1)
print("Loaded v3.x HD model.")
except:
try:
from model.RIFE_HD import Model
self.model = Model()
self.model.load_model(model_path, -1)
print("Loaded v1.x HD model")
except:
from model.RIFE import Model
self.model = Model()
self.model.load_model(model_path, -1)
print("Loaded ArXiv-RIFE model")
self.model.eval()
self.model.device()
def interpolate(self, img0, img1, exp=4, ratio=0, rthreshold=0.02, rmaxcycles=8):
if isinstance(img0, str) and isinstance(img1, str):
if img0.endswith('.exr') and img1.endswith('.exr'):
img0 = cv2.imread(img0, cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
img1 = cv2.imread(img1, cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
else:
img0 = cv2.imread(img0, cv2.IMREAD_UNCHANGED)
img1 = cv2.imread(img1, cv2.IMREAD_UNCHANGED)
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
elif isinstance(img0, torch.Tensor) and isinstance(img1, torch.Tensor):
img0 = img0.to(device)
img1 = img1.to(device)
else:
raise ValueError("Input images must be either file paths or torch tensors")
n, c, h, w = img0.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
img0 = F.pad(img0, padding)
img1 = F.pad(img1, padding)
if ratio:
img_list = [img0]
img0_ratio = 0.0
img1_ratio = 1.0
if ratio <= img0_ratio + rthreshold / 2:
middle = img0
elif ratio >= img1_ratio - rthreshold / 2:
middle = img1
else:
tmp_img0 = img0
tmp_img1 = img1
for inference_cycle in range(rmaxcycles):
middle = self.model.inference(tmp_img0, tmp_img1)
middle_ratio = (img0_ratio + img1_ratio) / 2
if ratio - (rthreshold / 2) <= middle_ratio <= ratio + (rthreshold / 2):
break
if ratio > middle_ratio:
tmp_img0 = middle
img0_ratio = middle_ratio
else:
tmp_img1 = middle
img1_ratio = middle_ratio
img_list.append(middle)
img_list.append(img1)
else:
img_list = [img0, img1]
for i in range(exp):
tmp = []
for j in range(len(img_list) - 1):
mid = self.model.inference(img_list[j], img_list[j + 1])
tmp.append(img_list[j])
tmp.append(mid)
tmp.append(img1)
img_list = tmp
return [img[:, :, :h, :w] for img in img_list]
def unload(self):
if hasattr(self, 'model'):
del self.model
torch.cuda.empty_cache()
print("RIFE model unloaded and CUDA cache cleared.")
@classmethod
def reset(cls):
if cls._instance is not None:
cls._instance.unload()
cls._instance = None
print("RIFE instance reset.")
def save_images(img_list, output_dir='output', img0_path='', img1_path=''):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for i, img in enumerate(img_list):
if img0_path.endswith('.exr') and img1_path.endswith('.exr'):
cv2.imwrite(os.path.join(output_dir, f'img{i}.exr'), img[0].cpu().numpy().transpose(1, 2, 0), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
else:
cv2.imwrite(os.path.join(output_dir, f'img{i}.png'), (img[0] * 255).byte().cpu().numpy().transpose(1, 2, 0))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--img', dest='img', nargs=2, required=True)
parser.add_argument('--exp', default=4, type=int)
parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
parser.add_argument('--model', dest='modelDir', type=str, default='./ckpt_models/rife', help='directory with trained model files')
args = parser.parse_args()
rife = RIFE(args.modelDir)
img_list = rife.interpolate(args.img[0], args.img[1], args.exp, args.ratio, args.rthreshold, args.rmaxcycles)
save_images(img_list, img0_path=args.img[0], img1_path=args.img[1])
rife.unload()
|