# FollowYourEmoji: extra/inference_img.py
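"""Inference script for FollowYourEmoji.

Renders a motion-driven video from a reference image with the diffusion VideoPipeline,
filters or interpolates anomalous frames via SSIM / mean-difference heuristics, and can
upsample the result with RIFE frame interpolation.
"""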
import os
import imageio
import numpy as np
from PIL import Image
import cv2
from omegaconf import OmegaConf
from skimage.metrics import structural_similarity as ssim
from collections import deque
import torch
import gc
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection
from models.guider import Guider
from models.referencenet import ReferenceNet2DConditionModel
from models.unet import UNet3DConditionModel
from models.video_pipeline import VideoPipeline
from dataset.val_dataset import ValDataset, val_collate_fn
from rife import RIFE
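# Load only the checkpoint entries whose keys exist in the model's state dict (non-strict).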
def load_model_state_dict(model, model_ckpt_path, name):
ckpt = torch.load(model_ckpt_path, map_location="cpu")
model_state_dict = model.state_dict()
model_new_sd = {}
count = 0
for k, v in ckpt.items():
if k in model_state_dict:
count += 1
model_new_sd[k] = v
miss, _ = model.load_state_dict(model_new_sd, strict=False)
print(f'load {name} from {model_ckpt_path}\n - loaded params: {count}\n - missing params: {len(miss)}')
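# Compare two consecutive RGB frames: grayscale SSIM plus mean absolute pixel difference.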
def frame_analysis(prev_frame, curr_frame):
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)
ssim_score = ssim(prev_gray, curr_gray)
mean_diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
return ssim_score, mean_diff
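# Heuristic anomaly check: only fires once the 5-frame history window is full, and flags a
# frame when either the absolute thresholds (SSIM < 0.85 and mean diff > 6.0) or the
# deviation-from-recent-average thresholds are exceeded.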
def is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
if len(ssim_history) < 5:
return False
ssim_avg = np.mean(ssim_history)
mean_diff_avg = np.mean(mean_diff_history)
ssim_threshold = 0.85
mean_diff_threshold = 6.0
ssim_change_threshold = 0.05
mean_diff_change_threshold = 3.0
if (ssim_score < ssim_threshold and mean_diff > mean_diff_threshold) or \
(ssim_score < ssim_avg - ssim_change_threshold and mean_diff > mean_diff_avg + mean_diff_change_threshold):
return True
return False
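# Run the diffusion pipeline over the validation batches, handle anomalous frames according
# to anomaly_action ("remove", "interpolate", or "removeAndInterpolate"), optionally upsample
# the result with RIFE, and write the {lmk_name}_temp/_oo and {lmk_name}_all videos.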
@torch.no_grad()
def visualize(dataloader, pipeline, generator, W, H, video_length, num_inference_steps,
              guidance_scale, output_path, output_fps=7, limit=1, show_stats=False,
              anomaly_action="remove", interpolate_anomaly_multiplier=1,
              interpolate_result=False, interpolate_result_multiplier=1):
oo_video_path = None
all_video_path = None
for i, batch in enumerate(dataloader):
ref_frame = batch['ref_frame'][0]
clip_image = batch['clip_image'][0]
motions = batch['motions'][0]
file_name = batch['file_name'][0]
if motions is None:
continue
if 'lmk_name' in batch:
lmk_name = batch['lmk_name'][0].split('.')[0]
else:
lmk_name = 'lmk'
print(file_name, lmk_name)
ref_frame = torch.clamp((ref_frame + 1.0) / 2.0, min=0, max=1)
ref_frame = ref_frame.permute((1, 2, 3, 0)).squeeze()
ref_frame = (ref_frame * 255).cpu().numpy().astype(np.uint8)
ref_image = Image.fromarray(ref_frame)
motions = motions.permute((1, 2, 3, 0))
motions = (motions * 255).cpu().numpy().astype(np.uint8)
lmk_images = [Image.fromarray(motion) for motion in motions]
preds = pipeline(ref_image=ref_image,
lmk_images=lmk_images,
width=W,
height=H,
video_length=video_length,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
clip_image=clip_image,
).videos
preds = preds.permute((0,2,3,4,1)).squeeze(0)
preds = (preds * 255).cpu().numpy().astype(np.uint8)
filtered_preds = []
prev_frame = None
ssim_history = deque(maxlen=5)
mean_diff_history = deque(maxlen=5)
normal_frames = []
# Clear memory
torch.cuda.empty_cache()
gc.collect()
rife = RIFE()
for idx, frame in enumerate(preds):
if prev_frame is not None:
ssim_score, mean_diff = frame_analysis(prev_frame, frame)
ssim_history.append(ssim_score)
mean_diff_history.append(mean_diff)
if show_stats:
print(f"Frame {idx}: SSIM: {ssim_score:.4f}, Mean Diff: {mean_diff:.4f}")
if is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
print(f"Anomaly detected in frame {idx}")
if anomaly_action == "remove":
continue
elif anomaly_action == "interpolate":
if filtered_preds:
interpolated = rife.interpolate_frames(filtered_preds[-1], frame, interpolate_anomaly_multiplier)
filtered_preds.extend(interpolated[1:])
continue
elif anomaly_action == "removeAndInterpolate":
if normal_frames:
last_normal = normal_frames[-1]
interpolated = rife.interpolate_frames(preds[last_normal], frame, idx - last_normal + 1)
filtered_preds.extend(interpolated[1:])
continue
filtered_preds.append(frame)
normal_frames.append(idx)
prev_frame = frame
# Save the intermediate video
temp_video_path = os.path.join(output_path, f"{lmk_name}_temp.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video_path, fourcc, output_fps, (W, H))
for frame in filtered_preds:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # OpenCV VideoWriter expects BGR
out.release()
# Apply RIFE to interpolate the result if requested
if interpolate_result:
exp = int(np.log2(interpolate_result_multiplier))
oo_video_path = os.path.join(output_path, f"{lmk_name}_oo.mp4")
new_frame_count, new_fps = rife.interpolate_video(temp_video_path, oo_video_path, exp=exp, fps=output_fps*(2**exp))
video_length = new_frame_count
output_fps = new_fps
else:
oo_video_path = temp_video_path
video_length = len(filtered_preds)
# Build a video with all frames side by side (original, motion, reference, predicted)
if 'frames' in batch:
frames = batch['frames'][0]
frames = torch.clamp((frames + 1.0) / 2.0, min=0, max=1)
frames = frames.permute((1, 2, 3, 0))
frames = (frames * 255).cpu().numpy().astype(np.uint8)
# Read frames back from the interpolated video
cap = cv2.VideoCapture(oo_video_path)
interpolated_frames = []
while True:
ret, frame = cap.read()
if not ret:
break
interpolated_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # back to RGB for compositing
cap.release()
# Tile the original and motion frames by repetition to match the number of interpolated frames
frames_extended = [cv2.resize(frame, (W, H), interpolation=cv2.INTER_LINEAR) for frame in frames]
frames_extended = frames_extended * (len(interpolated_frames) // len(frames_extended)) + frames_extended[:(len(interpolated_frames) % len(frames_extended))]
motions_extended = [cv2.resize(motion, (W, H), interpolation=cv2.INTER_LINEAR) for motion in motions]
motions_extended = motions_extended * (len(interpolated_frames) // len(motions_extended)) + motions_extended[:(len(interpolated_frames) % len(motions_extended))]
combined = [np.concatenate((frame, motion, ref_frame, pred), axis=1)
for frame, motion, pred in zip(frames_extended, motions_extended, interpolated_frames)]
else:
# Read frames back from the interpolated video
cap = cv2.VideoCapture(oo_video_path)
interpolated_frames = []
while True:
ret, frame = cap.read()
if not ret:
break
interpolated_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # back to RGB for compositing
cap.release()
# Tile the motion frames by repetition to match the number of interpolated frames
motions_extended = [cv2.resize(motion, (W, H), interpolation=cv2.INTER_LINEAR) for motion in motions]
motions_extended = motions_extended * (len(interpolated_frames) // len(motions_extended)) + motions_extended[:(len(interpolated_frames) % len(motions_extended))]
combined = [np.concatenate((motion, ref_frame, pred), axis=1)
for motion, pred in zip(motions_extended, interpolated_frames)]
all_video_path = os.path.join(output_path, f"{lmk_name}_all.mp4")
out = cv2.VideoWriter(all_video_path, fourcc, output_fps, (combined[0].shape[1], combined[0].shape[0]))
for frame in combined:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # OpenCV VideoWriter expects BGR
out.release()
if i >= limit:
break
rife.unload()
return oo_video_path, all_video_path, video_length, output_fps
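# Load the VAE, CLIP image encoder, ReferenceNet, 3D UNet and landmark guider described by
# the config, assemble them into a VideoPipeline, and run visualize() over the validation set.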
@torch.no_grad()
def infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps, show_stats,
anomaly_action, interpolate_anomaly_multiplier, interpolate_result, interpolate_result_multiplier):
config = OmegaConf.load(config_path)
config.init_checkpoint = model_path
config.init_num = model_step
config.resolution_w = resolution_w
config.resolution_h = resolution_h
config.video_length = video_length
if config.weight_dtype == "fp16":
weight_dtype = torch.float16
elif config.weight_dtype == "fp32":
weight_dtype = torch.float32
else:
raise ValueError(f"Do not support weight dtype: {config.weight_dtype}")
vae = AutoencoderKL.from_pretrained(config.vae_model_path).to(dtype=weight_dtype, device="cuda")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
referencenet = ReferenceNet2DConditionModel.from_pretrained_2d(config.base_model_path,
referencenet_additional_kwargs=config.model.referencenet_additional_kwargs).to(device="cuda")
unet = UNet3DConditionModel.from_pretrained_2d(config.base_model_path,
motion_module_path=config.motion_module_path,
unet_additional_kwargs=config.model.unet_additional_kwargs).to(device="cuda")
lmk_guider = Guider(conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)).to(device="cuda")
load_model_state_dict(referencenet, f'{config.init_checkpoint}/referencenet.pth', 'referencenet')
load_model_state_dict(unet, f'{config.init_checkpoint}/unet.pth', 'unet')
load_model_state_dict(lmk_guider, f'{config.init_checkpoint}/lmk_guider.pth', 'lmk_guider')
if config.enable_xformers_memory_efficient_attention:
if is_xformers_available():
referencenet.enable_xformers_memory_efficient_attention()
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
unet.set_reentrant(use_reentrant=False)
referencenet.set_reentrant(use_reentrant=False)
vae.eval()
image_encoder.eval()
unet.eval()
referencenet.eval()
lmk_guider.eval()
sched_kwargs = OmegaConf.to_container(config.scheduler)
if config.enable_zero_snr:
sched_kwargs.update(rescale_betas_zero_snr=True,
timestep_spacing="trailing",
prediction_type="v_prediction")
noise_scheduler = DDIMScheduler(**sched_kwargs)
pipeline = VideoPipeline(vae=vae,
image_encoder=image_encoder,
referencenet=referencenet,
unet=unet,
lmk_guider=lmk_guider,
scheduler=noise_scheduler).to(vae.device, dtype=weight_dtype)
val_dataset = ValDataset(
input_path=input_path,
lmk_path=lmk_path,
resolution_h=config.resolution_h,
resolution_w=config.resolution_w
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=1,
num_workers=0,
shuffle=False,
collate_fn=val_collate_fn,
)
generator = torch.Generator(device=vae.device)
generator.manual_seed(seed)
oo_video_path, all_video_path, new_video_length, new_output_fps = visualize(
val_dataloader,
pipeline,
generator,
W=config.resolution_w,
H=config.resolution_h,
video_length=config.video_length,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
output_path=output_path,
output_fps=output_fps,
show_stats=show_stats,
anomaly_action=anomaly_action,
interpolate_anomaly_multiplier=interpolate_anomaly_multiplier,
interpolate_result=interpolate_result,
interpolate_result_multiplier=interpolate_result_multiplier,
limit=100000000
)
del vae, image_encoder, referencenet, unet, lmk_guider, pipeline
torch.cuda.empty_cache()
gc.collect()
return "Inference completed successfully", oo_video_path, all_video_path, new_video_length, new_output_fps
def run_inference(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
resolution_w, resolution_h, video_length, num_inference_steps=30, guidance_scale=3.5, output_fps=30,
show_stats=False, anomaly_action="remove", interpolate_anomaly_multiplier=1,
interpolate_result=False, interpolate_result_multiplier=1):
try:
# Clear memory
torch.cuda.empty_cache()
gc.collect()
return infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps,
show_stats, anomaly_action, interpolate_anomaly_multiplier,
interpolate_result, interpolate_result_multiplier)
finally:
torch.cuda.empty_cache()
gc.collect()
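# Example invocation (a minimal sketch: every path and value below is a hypothetical
# placeholder for illustration and is not shipped with this repo):
#
# if __name__ == "__main__":
#     status, oo_path, all_path, n_frames, fps = run_inference(
#         config_path="configs/infer.yaml",            # hypothetical config file
#         model_path="checkpoints/follow-your-emoji",   # hypothetical checkpoint dir with *.pth files
#         input_path="assets/ref.png",                  # reference image
#         lmk_path="assets/motion_lmk.npy",             # landmark / motion sequence
#         output_path="outputs",
#         model_step=0,
#         seed=42,
#         resolution_w=512,
#         resolution_h=512,
#         video_length=16,
#         interpolate_result=True,                      # optional RIFE upsampling of the result
#         interpolate_result_multiplier=2,
#     )
#     print(status, oo_path, all_path, n_frames, fps)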