# Copyright 2023 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Callable, List, Optional, Union

import cv2
import decord
import numpy as np
import torch
from einops import rearrange, repeat

from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation

from ..models.controlnet3d import ControlNet3DModel
from ..models.unet_3d_condition import UNetPseudo3DConditionModel
from .pipeline_st_stable_diffusion import SpatioTemporalStableDiffusionPipeline


class Controlnet3DStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNetPseudo3DConditionModel,
        controlnet: ControlNet3DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        annotator_model=None,
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
        self.annotator_model = annotator_model
        self.controlnet = controlnet
        self.unet = unet
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )

    @staticmethod
    def get_frames_preprocess(data_path, num_frames=24, sampling_rate=1, begin_indice=0, return_np=False):
        """Read `num_frames` frames from a video file; returns a (1, c, f, h, w) tensor in [-1, 1] and the video fps."""
        vr = decord.VideoReader(data_path)
        n_images = len(vr)
        fps_vid = round(vr.get_avg_fps())
        frame_indices = [begin_indice + i * sampling_rate for i in range(num_frames)]  # sample num_frames frames
        while n_images <= frame_indices[-1]:
            # the clip would run past the end of the video: lower the sampling rate until it fits
            sampling_rate -= 1
            if sampling_rate == 0:
                # NOTE boundary check: the video is too short even at sampling_rate=1
                return None, None
            frame_indices = [i * sampling_rate for i in range(num_frames)]
        frames = vr.get_batch(frame_indices).asnumpy()
        if return_np:
            return frames, fps_vid
        frames = torch.from_numpy(frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        return frames, fps_vid
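
    # A minimal usage sketch for the helper above ("clip.mp4" is a hypothetical path):
    #
    #   frames, fps = Controlnet3DStableDiffusionPipeline.get_frames_preprocess("clip.mp4", num_frames=8)
    #   # frames: (1, c, f, h, w) in [-1, 1], or (None, None) if the clip is too short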

    def get_canny_edge_map(self, frames):
        # frames: ((b f), c, h, w) tensor in [-1, 1]
        inputs = frames.cpu().numpy()
        inputs = rearrange(inputs, "f c h w -> f h w c")
        # rescale from [-1, 1] to [0, 255]
        inputs = (inputs + 1) * 127.5
        inputs = inputs.astype(np.uint8)
        lower_threshold = 100
        higher_threshold = 200
        edge_images = np.stack([cv2.Canny(inp, lower_threshold, higher_threshold) for inp in inputs])
        # back to a tensor in [-1, 1]
        edge_images = torch.from_numpy(edge_images).unsqueeze(1)  # f, 1, h, w
        edge_images = edge_images.div(255) * 2 - 1
        return edge_images.to(dtype=self.controlnet.dtype, device=self.controlnet.device)

    def get_depth_map(self, frames, height, width, return_standard_norm=False):
        """
        frames should be shaped (f, c, h, w); reshape b f c h w -> (b f) c h w first if needed
        """
        h, w = height, width
        inputs = torch.nn.functional.interpolate(
            frames,
            size=(384, 384),
            mode="bicubic",
            antialias=True,
        )
        # match the annotator's dtype and device
        inputs = inputs.to(dtype=self.annotator_model.dtype, device=self.annotator_model.device)
        outputs = self.annotator_model(inputs)
        predicted_depths = outputs.predicted_depth
        # interpolate back to the original size
        predictions = torch.nn.functional.interpolate(
            predicted_depths.unsqueeze(1),
            size=(h, w),
            mode="bicubic",
        )
        # normalize the output
        if return_standard_norm:
            # to [-1, 1]
            depth_min = torch.amin(predictions, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(predictions, dim=[1, 2, 3], keepdim=True)
            predictions = 2.0 * (predictions - depth_min) / (depth_max - depth_min) - 1.0
        else:
            # to [0, 1]
            predictions -= torch.min(predictions)
            predictions /= torch.max(predictions)
        return predictions
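
    # Continuing the sketch above: turning the decoded frames into a depth controlnet_hint
    # for __call__ (`pipe` is an instantiated pipeline, see the sketch at the end of the
    # file; the 3-channel repeat is an assumption, adapt it to your ControlNet):
    #
    #   depth = pipe.get_depth_map(rearrange(frames, "b c f h w -> (b f) c h w"), 512, 512)  # (f, 1, h, w), [0, 1]
    #   hint = rearrange(depth, "(b f) c h w -> b c f h w", b=1).repeat(1, 3, 1, 1, 1)        # (1, 3, f, 512, 512)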

    def get_hed_map(self, frames):
        if isinstance(frames, torch.Tensor):
            # tensor input is (b, c, h, w) in [-1, 1]; rescale to [0, 1]
            frames = (frames + 1) / 2
            # RGB -> BGR
            bgr_frames = frames.clone()
            bgr_frames[:, 0, :, :] = frames[:, 2, :, :]
            bgr_frames[:, 2, :, :] = frames[:, 0, :, :]
            edge = self.annotator_model(bgr_frames)  # output is also in [0, 1]
            return edge
        else:
            # numpy input: a single (h, w, c) image in [0, 255]
            assert frames.ndim == 3
            frames = frames[:, :, ::-1].copy()  # flip channel order (RGB <-> BGR)
            with torch.no_grad():
                image_hed = torch.from_numpy(frames).to(
                    next(self.annotator_model.parameters()).device,
                    dtype=next(self.annotator_model.parameters()).dtype,
                )
                image_hed = image_hed / 255.0
                image_hed = rearrange(image_hed, "h w c -> 1 c h w")
                edge = self.annotator_model(image_hed)[0]
                edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
            return edge[0]

    def get_pose_map(self, frames):
        if isinstance(frames, torch.Tensor):
            # tensor input is (f, c, h, w) in [-1, 1]; rescale to [0, 1]
            frames = (frames + 1) / 2
            np_frames = frames.cpu().numpy() * 255
            np_frames = np.array(np_frames, dtype=np.uint8)
            np_frames = rearrange(np_frames, "f c h w -> f h w c")
            poses = np.stack([self.annotator_model(inp) for inp in np_frames])
        else:
            poses = self.annotator_model(frames)
        return poses

    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    def __call__(
        self,
        prompt: Union[str, List[str]],
        controlnet_hint=None,
        fps_labels=None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        clip_length: int = 8,  # NOTE clip_length must match the number of frames in `images`
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs=None,
        video_scale: float = 0.0,
        controlnet_conditioning_scale: float = 1.0,
        fix_first_frame=True,
        first_frame_output=None,  # a pre-selected first frame may also be passed in directly
        first_frame_output_latent=None,
        first_frame_control_hint=None,  # control signal that keeps the first frame fixed
        add_first_frame_by_concat=False,
        controlhint_in_uncond=False,
        init_same_noise_per_frame=False,
        init_noise_by_residual_thres=0.0,
        images=None,
        in_domain=False,  # whether to use the video model itself to generate the first frame
        residual_control_steps=1,
        first_frame_ddim_strength=1.0,
        return_last_latent=False,
    ):
        """
        Pass the original video frames via `images` so depth maps / frame residuals can be derived from them.
        """
        if fix_first_frame and first_frame_output is None and first_frame_output_latent is None:
            # generate the first frame on its own, then keep it fixed for the rest of the clip
            first_frame_output = self.__call__(
                prompt=prompt,
                controlnet_hint=controlnet_hint[:, :, 0, :, :] if not in_domain else controlnet_hint[:, :, 0:1, :, :],  # b c f h w
                num_inference_steps=20,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
                fix_first_frame=False,
                controlhint_in_uncond=controlhint_in_uncond,
            ).images[0]
        if first_frame_output is not None:
            if isinstance(first_frame_output, list):
                first_frame_output = first_frame_output[0]
            first_frame_output = torch.from_numpy(np.array(first_frame_output)).div(255) * 2 - 1
            first_frame_output = rearrange(first_frame_output, "h w c -> c h w").unsqueeze(0)  # FIXME batch size is fixed to 1 for now
            first_frame_output = first_frame_output.to(dtype=self.vae.dtype, device=self.vae.device)
            first_frame_output_latent = self.vae.encode(first_frame_output).latent_dist.sample()
            first_frame_output_latent = first_frame_output_latent * 0.18215

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 5.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        if controlnet_hint is not None:
            if len(controlnet_hint.shape) == 5:
                clip_length = controlnet_hint.shape[2]
            else:
                clip_length = 0
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            clip_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype
        if len(latents.shape) == 5 and init_same_noise_per_frame:
            # share the same initial noise across all frames
            latents[:, :, 1:, :, :] = latents[:, :, 0:1, :, :]
        if len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
            # build a per-pixel motion mask from frame-to-frame residuals of the source video
            images = images.to(device=device, dtype=latents_dtype)  # b c f h w
            image_residual = torch.abs(images[:, :, 1:, :, :] - images[:, :, :-1, :, :])
            images = rearrange(images, "b c f h w -> (b f) c h w")
            # normalize the residual
            image_residual = image_residual / torch.max(image_residual)
            image_residual = rearrange(image_residual, "b c f h w -> (b f) c h w")
            image_residual = torch.nn.functional.interpolate(
                image_residual,
                size=(latents.shape[-2], latents.shape[-1]),
                mode="bilinear",
            )
            image_residual = torch.mean(image_residual, dim=1)
            image_residual_mask = (image_residual > init_noise_by_residual_thres).float()
            image_residual_mask = repeat(image_residual_mask, "(b f) h w -> b f h w", b=batch_size)
            image_residual_mask = repeat(image_residual_mask, "b f h w -> b c f h w", c=latents.shape[1])

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            if fix_first_frame:
                if add_first_frame_by_concat:
                    # prepend the fixed first-frame latent (and its control hint) to the clip
                    if len(first_frame_output_latent.shape) == 4:
                        latents = torch.cat([first_frame_output_latent.unsqueeze(2), latents], dim=2)
                    else:
                        latents = torch.cat([first_frame_output_latent, latents], dim=2)
                    if first_frame_control_hint is not None:
                        controlnet_hint = torch.cat([first_frame_control_hint, controlnet_hint], dim=2)
                    else:
                        controlnet_hint = torch.cat([controlnet_hint[:, :, 0:1, :, :], controlnet_hint], dim=2)
            if controlhint_in_uncond:
                controlnet_hint = torch.cat([controlnet_hint] * 2) if do_classifier_free_guidance else controlnet_hint
            for i, t in enumerate(timesteps):
                if i < residual_control_steps and len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
                    if first_frame_ddim_strength < 1.0 and i == 0:
                        # NOTE partially noise the first-frame latent via DDIM to get the initial noise
                        first_frame_output_latent_DDIM = first_frame_output_latent.clone()
                        full_noise_timestep, _ = self.get_timesteps(num_inference_steps, strength=first_frame_ddim_strength)
                        latent_timestep = full_noise_timestep[:1].repeat(batch_size * num_images_per_prompt)
                        first_frame_output_latent_DDIM = self.scheduler.add_noise(
                            first_frame_output_latent_DDIM, latents[:, :, 0, :, :], latent_timestep
                        )
                        latents[:, :, 0, :, :] = first_frame_output_latent_DDIM
                    # copy the previous frame's noise in static regions; keep each frame's own
                    # noise where the residual mask indicates motion
                    begin_frame = 1
                    for n_frame in range(begin_frame, latents.shape[2]):
                        latents[:, :, n_frame, :, :] = (
                            latents[:, :, n_frame, :, :] - latents[:, :, n_frame - 1, :, :]
                        ) * image_residual_mask[:, :, n_frame - 1, :, :] + latents[:, :, n_frame - 1, :, :]
                if fix_first_frame:
                    latents[:, :, 0, :, :] = first_frame_output_latent
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                if controlnet_hint is not None:
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        controlnet_cond=controlnet_hint,
                        return_dict=False,
                    )
                    down_block_res_samples = [
                        down_block_res_sample * controlnet_conditioning_scale
                        for down_block_res_sample in down_block_res_samples
                    ]
                    mid_block_res_sample *= controlnet_conditioning_scale
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample.to(dtype=latents_dtype)
                else:
                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample.to(dtype=latents_dtype)
                if video_scale > 0 and controlnet_hint is not None:
                    # frame-wise (single-frame) unconditional prediction, used below for video guidance
                    bsz = latents.shape[0]
                    f = latents.shape[2]
                    latent_model_input_single_frame = rearrange(latent_model_input, "b c f h w -> (b f) c h w")
                    text_embeddings_single_frame = torch.cat([text_embeddings] * f, dim=0)
                    control_maps_single_frame = rearrange(controlnet_hint, "b c f h w -> (b f) c h w")
                    # use only half of the classifier-free-guidance batch for the per-frame prediction
                    latent_model_input_single_frame = latent_model_input_single_frame.chunk(2, dim=0)[0]
                    text_embeddings_single_frame = text_embeddings_single_frame.chunk(2, dim=0)[0]
                    if controlhint_in_uncond:
                        control_maps_single_frame = control_maps_single_frame.chunk(2, dim=0)[0]
                    down_block_res_samples_single_frame, mid_block_res_sample_single_frame = self.controlnet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        controlnet_cond=control_maps_single_frame,
                        return_dict=False,
                    )
                    down_block_res_samples_single_frame = [
                        down_block_res_sample_single_frame * controlnet_conditioning_scale
                        for down_block_res_sample_single_frame in down_block_res_samples_single_frame
                    ]
                    mid_block_res_sample_single_frame *= controlnet_conditioning_scale
                    noise_pred_single_frame_uncond = self.unet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        down_block_additional_residuals=down_block_res_samples_single_frame,
                        mid_block_additional_residual=mid_block_res_sample_single_frame,
                    ).sample
                    noise_pred_single_frame_uncond = rearrange(
                        noise_pred_single_frame_uncond, "(b f) c h w -> b c f h w", f=f
                    )
                # perform guidance
                if do_classifier_free_guidance:
                    if video_scale > 0 and controlnet_hint is not None:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_single_frame_uncond + video_scale * (
                            noise_pred_uncond - noise_pred_single_frame_uncond
                        ) + guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)
        if add_first_frame_by_concat:
            image = image[:, 1:, :, :, :]

        # 9. Run safety checker
        has_nsfw_concept = None

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        if return_last_latent:
            last_latent = latents[:, :, -1, :, :]
            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), last_latent
        else:
            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
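

# A minimal end-to-end sketch (not part of the pipeline). The checkpoint path, the depth
# annotator, and the call arguments are assumptions for illustration only; substitute the
# components you actually trained or downloaded.
#
#   from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation
#   from diffusers import AutoencoderKL, DDIMScheduler
#
#   model_dir = "./checkpoints/controlnet3d-sd"  # hypothetical path
#   pipe = Controlnet3DStableDiffusionPipeline(
#       vae=AutoencoderKL.from_pretrained(model_dir, subfolder="vae"),
#       text_encoder=CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder"),
#       tokenizer=CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer"),
#       unet=UNetPseudo3DConditionModel.from_pretrained(model_dir, subfolder="unet"),
#       controlnet=ControlNet3DModel.from_pretrained(model_dir, subfolder="controlnet"),
#       scheduler=DDIMScheduler.from_pretrained(model_dir, subfolder="scheduler"),
#       annotator_model=DPTForDepthEstimation.from_pretrained("Intel/dpt-large"),
#   ).to("cuda")
#
#   # `hint` built from the source video as in the get_depth_map sketch above, shaped (1, 3, f, h, w)
#   result = pipe("a watercolor painting of the same scene",
#                 controlnet_hint=hint, height=512, width=512, num_inference_steps=50)
#   result.images  # list of PIL frames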