import os
import imageio
import numpy as np
from PIL import Image
import cv2
from omegaconf import OmegaConf
from skimage.metrics import structural_similarity as ssim
from collections import deque
import torch
import gc
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection
from models.guider import Guider
from models.referencenet import ReferenceNet2DConditionModel
from models.unet import UNet3DConditionModel
from models.video_pipeline import VideoPipeline
from dataset.val_dataset import ValDataset, val_collate_fn
from rife import RIFE


def load_model_state_dict(model, model_ckpt_path, name):
    # Copy into the model only those checkpoint entries whose keys it actually has,
    # then report how many parameters were loaded and which ones are still missing.
    ckpt = torch.load(model_ckpt_path, map_location="cpu")
    model_state_dict = model.state_dict()
    model_new_sd = {}
    count = 0
    for k, v in ckpt.items():
        if k in model_state_dict:
            count += 1
            model_new_sd[k] = v
    miss, _ = model.load_state_dict(model_new_sd, strict=False)
    print(f'load {name} from {model_ckpt_path}\n - load params: {count}\n - miss params: {miss}')


def frame_analysis(prev_frame, curr_frame):
    # Compare two consecutive RGB frames: SSIM on the grayscale versions plus
    # the mean absolute pixel difference.
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
    curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)
    ssim_score = ssim(prev_gray, curr_gray)
    mean_diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
    return ssim_score, mean_diff


def is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
    # A frame is flagged as anomalous once enough history has accumulated and it
    # either violates the absolute thresholds or deviates sharply from the recent
    # running averages.
    if len(ssim_history) < 5:
        return False
    ssim_avg = np.mean(ssim_history)
    mean_diff_avg = np.mean(mean_diff_history)
    ssim_threshold = 0.85
    mean_diff_threshold = 6.0
    ssim_change_threshold = 0.05
    mean_diff_change_threshold = 3.0
    if (ssim_score < ssim_threshold and mean_diff > mean_diff_threshold) or \
       (ssim_score < ssim_avg - ssim_change_threshold and mean_diff > mean_diff_avg + mean_diff_change_threshold):
        return True
    return False


def visualize(dataloader, pipeline, generator, W, H, video_length, num_inference_steps, guidance_scale,
              output_path, output_fps=7, limit=1, show_stats=False, anomaly_action="remove",
              interpolate_anomaly_multiplier=1, interpolate_result=False, interpolate_result_multiplier=1):
    # Run the pipeline over the dataloader, filter out anomalous frames, optionally
    # interpolate with RIFE, and write the resulting videos to output_path.
    oo_video_path = None
    all_video_path = None
    for i, batch in enumerate(dataloader):
        ref_frame = batch['ref_frame'][0]
        clip_image = batch['clip_image'][0]
        motions = batch['motions'][0]
        file_name = batch['file_name'][0]
        if motions is None:
            continue
        if 'lmk_name' in batch:
            lmk_name = batch['lmk_name'][0].split('.')[0]
        else:
            lmk_name = 'lmk'
        print(file_name, lmk_name)
        # Convert the reference frame and the landmark/motion frames to uint8 images for the pipeline.
        ref_frame = torch.clamp((ref_frame + 1.0) / 2.0, min=0, max=1)
        ref_frame = ref_frame.permute((1, 2, 3, 0)).squeeze()
        ref_frame = (ref_frame * 255).cpu().numpy().astype(np.uint8)
        ref_image = Image.fromarray(ref_frame)
        motions = motions.permute((1, 2, 3, 0))
        motions = (motions * 255).cpu().numpy().astype(np.uint8)
        lmk_images = [Image.fromarray(motion) for motion in motions]
        preds = pipeline(ref_image=ref_image,
                         lmk_images=lmk_images,
                         width=W,
                         height=H,
                         video_length=video_length,
                         num_inference_steps=num_inference_steps,
                         guidance_scale=guidance_scale,
                         generator=generator,
                         clip_image=clip_image,
                         ).videos
        preds = preds.permute((0, 2, 3, 4, 1)).squeeze(0)
        preds = (preds * 255).cpu().numpy().astype(np.uint8)
        # Filter the generated frames, handling anomalies according to anomaly_action.
        filtered_preds = []
        prev_frame = None
        ssim_history = deque(maxlen=5)
        mean_diff_history = deque(maxlen=5)
        normal_frames = []
        # Clear memory
        torch.cuda.empty_cache()
        gc.collect()
        rife = RIFE()
        for idx, frame in enumerate(preds):
            if prev_frame is not None:
                ssim_score, mean_diff = frame_analysis(prev_frame, frame)
                ssim_history.append(ssim_score)
                mean_diff_history.append(mean_diff)
                if show_stats:
                    print(f"Frame {idx}: SSIM: {ssim_score:.4f}, Mean Diff: {mean_diff:.4f}")
                if is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
                    print(f"Anomaly detected in frame {idx}")
                    if anomaly_action == "remove":
                        continue
                    elif anomaly_action == "interpolate":
                        if filtered_preds:
                            interpolated = rife.interpolate_frames(filtered_preds[-1], frame, interpolate_anomaly_multiplier)
                            filtered_preds.extend(interpolated[1:])
                        continue
                    elif anomaly_action == "removeAndInterpolate":
                        if normal_frames:
                            last_normal = normal_frames[-1]
                            interpolated = rife.interpolate_frames(preds[last_normal], frame, idx - last_normal + 1)
                            filtered_preds.extend(interpolated[1:])
                        continue
            filtered_preds.append(frame)
            normal_frames.append(idx)
            prev_frame = frame
        # Save the intermediate video
        temp_video_path = os.path.join(output_path, f"{lmk_name}_temp.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_video_path, fourcc, output_fps, (W, H))
        for frame in filtered_preds:
            out.write(frame)
        out.release()
        # Apply RIFE to interpolate the result, if requested
        if interpolate_result:
            exp = int(np.log2(interpolate_result_multiplier))
            oo_video_path = os.path.join(output_path, f"{lmk_name}_oo.mp4")
            new_frame_count, new_fps = rife.interpolate_video(temp_video_path, oo_video_path, exp=exp, fps=output_fps * (2 ** exp))
            video_length = new_frame_count
            output_fps = new_fps
        else:
            oo_video_path = temp_video_path
            video_length = len(filtered_preds)
        # Build a video with all panels side by side (original, motion, reference, predicted)
        if 'frames' in batch:
            frames = batch['frames'][0]
            frames = torch.clamp((frames + 1.0) / 2.0, min=0, max=1)
            frames = frames.permute((1, 2, 3, 0))
            frames = (frames * 255).cpu().numpy().astype(np.uint8)
            # Read frames back from the interpolated video
            cap = cv2.VideoCapture(oo_video_path)
            interpolated_frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                interpolated_frames.append(frame)
            cap.release()
            # Stretch the original and motion frames to match the number of interpolated frames
            frames_extended = [cv2.resize(frame, (W, H), interpolation=cv2.INTER_LINEAR) for frame in frames]
            frames_extended = frames_extended * (len(interpolated_frames) // len(frames_extended)) + frames_extended[:(len(interpolated_frames) % len(frames_extended))]
            motions_extended = [cv2.resize(motion, (W, H), interpolation=cv2.INTER_LINEAR) for motion in motions]
            motions_extended = motions_extended * (len(interpolated_frames) // len(motions_extended)) + motions_extended[:(len(interpolated_frames) % len(motions_extended))]
            combined = [np.concatenate((frame, motion, ref_frame, pred), axis=1)
                        for frame, motion, pred in zip(frames_extended, motions_extended, interpolated_frames)]
        else:
            # Read frames back from the interpolated video
            cap = cv2.VideoCapture(oo_video_path)
            interpolated_frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                interpolated_frames.append(frame)
            cap.release()
            # Stretch the motion frames to match the number of interpolated frames
            motions_extended = [cv2.resize(motion, (W, H), interpolation=cv2.INTER_LINEAR) for motion in motions]
            motions_extended = motions_extended * (len(interpolated_frames) // len(motions_extended)) + motions_extended[:(len(interpolated_frames) % len(motions_extended))]
            combined = [np.concatenate((motion, ref_frame, pred), axis=1)
                        for motion, pred in zip(motions_extended, interpolated_frames)]
        all_video_path = os.path.join(output_path, f"{lmk_name}_all.mp4")
        out = cv2.VideoWriter(all_video_path, fourcc, output_fps, (combined[0].shape[1], combined[0].shape[0]))
        for frame in combined:
            out.write(frame)
        out.release()
        if i >= limit:
            break
    rife.unload()
    return oo_video_path, all_video_path, video_length, output_fps


def infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
          resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps, show_stats,
          anomaly_action, interpolate_anomaly_multiplier, interpolate_result, interpolate_result_multiplier):
    # Build the models and pipeline from the config, then run visualize() over the validation data.
    config = OmegaConf.load(config_path)
    config.init_checkpoint = model_path
    config.init_num = model_step
    config.resolution_w = resolution_w
    config.resolution_h = resolution_h
    config.video_length = video_length
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")
    vae = AutoencoderKL.from_pretrained(config.vae_model_path).to(dtype=weight_dtype, device="cuda")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
    referencenet = ReferenceNet2DConditionModel.from_pretrained_2d(
        config.base_model_path,
        referencenet_additional_kwargs=config.model.referencenet_additional_kwargs).to(device="cuda")
    unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        motion_module_path=config.motion_module_path,
        unet_additional_kwargs=config.model.unet_additional_kwargs).to(device="cuda")
    lmk_guider = Guider(conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)).to(device="cuda")
    load_model_state_dict(referencenet, f'{config.init_checkpoint}/referencenet.pth', 'referencenet')
    load_model_state_dict(unet, f'{config.init_checkpoint}/unet.pth', 'unet')
    load_model_state_dict(lmk_guider, f'{config.init_checkpoint}/lmk_guider.pth', 'lmk_guider')
    if config.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            referencenet.enable_xformers_memory_efficient_attention()
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    unet.set_reentrant(use_reentrant=False)
    referencenet.set_reentrant(use_reentrant=False)
    vae.eval()
    image_encoder.eval()
    unet.eval()
    referencenet.eval()
    lmk_guider.eval()
    sched_kwargs = OmegaConf.to_container(config.scheduler)
    if config.enable_zero_snr:
        sched_kwargs.update(rescale_betas_zero_snr=True,
                            timestep_spacing="trailing",
                            prediction_type="v_prediction")
    noise_scheduler = DDIMScheduler(**sched_kwargs)
    pipeline = VideoPipeline(vae=vae,
                             image_encoder=image_encoder,
                             referencenet=referencenet,
                             unet=unet,
                             lmk_guider=lmk_guider,
                             scheduler=noise_scheduler).to(vae.device, dtype=weight_dtype)
    val_dataset = ValDataset(
        input_path=input_path,
        lmk_path=lmk_path,
        resolution_h=config.resolution_h,
        resolution_w=config.resolution_w
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        shuffle=False,
        collate_fn=val_collate_fn,
    )
    generator = torch.Generator(device=vae.device)
    generator.manual_seed(seed)
    oo_video_path, all_video_path, new_video_length, new_output_fps = visualize(
        val_dataloader,
        pipeline,
        generator,
        W=config.resolution_w,
        H=config.resolution_h,
        video_length=config.video_length,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        output_path=output_path,
        output_fps=output_fps,
        show_stats=show_stats,
        anomaly_action=anomaly_action,
        interpolate_anomaly_multiplier=interpolate_anomaly_multiplier,
        interpolate_result=interpolate_result,
        interpolate_result_multiplier=interpolate_result_multiplier,
        limit=100000000
    )
    del vae, image_encoder, referencenet, unet, lmk_guider, pipeline
    torch.cuda.empty_cache()
    gc.collect()
    return "Inference completed successfully", oo_video_path, all_video_path, new_video_length, new_output_fps


def run_inference(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
                  resolution_w, resolution_h, video_length, num_inference_steps=30, guidance_scale=3.5, output_fps=30,
                  show_stats=False, anomaly_action="remove", interpolate_anomaly_multiplier=1,
                  interpolate_result=False, interpolate_result_multiplier=1):
    # Wrapper around infer() that clears GPU memory before and after the run.
    try:
        # Clear memory
        torch.cuda.empty_cache()
        gc.collect()
        return infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
                     resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps,
                     show_stats, anomaly_action, interpolate_anomaly_multiplier,
                     interpolate_result, interpolate_result_multiplier)
    finally:
        torch.cuda.empty_cache()
        gc.collect()
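

# Illustrative entry point (not part of the original script): a minimal sketch of how
# run_inference might be invoked from the command line. The paths, config name, seed,
# and checkpoint step below are placeholder assumptions, not values shipped with the project.
if __name__ == "__main__":
    status, oo_video, all_video, n_frames, fps = run_inference(
        config_path="configs/inference.yaml",   # hypothetical config file
        model_path="checkpoints",                # hypothetical checkpoint directory
        input_path="examples/ref.png",           # hypothetical reference image
        lmk_path="examples/lmk.mp4",             # hypothetical landmark/motion input
        output_path="outputs",
        model_step=0,                            # placeholder checkpoint step
        seed=42,
        resolution_w=512,
        resolution_h=512,
        video_length=16,
        anomaly_action="removeAndInterpolate",
        interpolate_result=True,
        interpolate_result_multiplier=2,
    )
    print(status, oo_video, all_video, n_frames, fps)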