import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
# Needed by ModelConfig.download_if_necessary below; the default download_resource is "ModelScope".
from modelscope import snapshot_download


class BasePipeline(torch.nn.Module):
    def __init__(
        self,
        device="cuda",
        torch_dtype=torch.float16,
        height_division_factor=64,
        width_division_factor=64,
        time_division_factor=None,
        time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for shape checks.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        # Keep self.device and self.torch_dtype in sync when the module is moved or cast.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (
                (height + self.height_division_factor - 1)
                // self.height_division_factor
                * self.height_division_factor
            )
            print(
                f"height % {self.height_division_factor} != 0. We round it up to {height}."
            )
        if width % self.width_division_factor != 0:
            width = (
                (width + self.width_division_factor - 1)
                // self.width_division_factor
                * self.width_division_factor
            )
            print(
                f"width % {self.width_division_factor} != 0. We round it up to {width}."
            )
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (
                    (num_frames + self.time_division_factor - 1)
                    // self.time_division_factor
                    * self.time_division_factor
                    + self.time_division_remainder
                )
                print(
                    f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}."
                )
            return height, width, num_frames
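    # Example (illustrative): with the default height_division_factor=64, a requested
    # 720x1280 resolution is rounded up to 768x1280 (1280 is already a multiple of 64).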

    def preprocess_image(
        self,
        image,
        torch_dtype=None,
        device=None,
        pattern="B C H W",
        min_value=-1,
        max_value=1,
    ):
        # Transform a PIL.Image to torch.Tensor
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(
            dtype=torch_dtype or self.torch_dtype, device=device or self.device
        )
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(
            image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})
        )
        return image

    def preprocess_video(
        self,
        video,
        torch_dtype=None,
        device=None,
        pattern="B C T H W",
        min_value=-1,
        max_value=1,
    ):
        # Transform a list of PIL.Image to torch.Tensor
        video = [
            self.preprocess_image(
                image,
                torch_dtype=torch_dtype,
                device=device,
                min_value=min_value,
                max_value=max_value,
            )
            for image in video
        ]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(
        self, vae_output, pattern="B C H W", min_value=-1, max_value=1
    ):
        # Transform a torch.Tensor to PIL.Image
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(
            0, 255
        )
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(
        self, vae_output, pattern="B C T H W", min_value=-1, max_value=1
    ):
        # Transform a torch.Tensor to list of PIL.Image
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [
            self.vae_output_to_image(
                image, pattern="H W C", min_value=min_value, max_value=max_value
            )
            for image in vae_output
        ]
        return video

    def load_models_to_device(self, model_names=[]):
        if self.vram_management_enabled:
            # offload models
            for name, model in self.named_children():
                if name not in model_names:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # onload models
            for name, model in self.named_children():
                if name in model_names:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)

    def generate_noise(
        self,
        shape,
        seed=None,
        rand_device="cpu",
        rand_torch_dtype=torch.float32,
        device=None,
        torch_dtype=None,
    ):
        # Initialize Gaussian noise
        generator = (
            None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        )
        noise = torch.randn(
            shape, generator=generator, device=rand_device, dtype=rand_torch_dtype
        )
        noise = noise.to(
            dtype=torch_dtype or self.torch_dtype, device=device or self.device
        )
        return noise

    def enable_cpu_offload(self):
        warnings.warn(
            "`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`."
        )
        self.vram_management_enabled = True

    def get_vram(self):
        # Total VRAM of `self.device` in GiB.
        return torch.cuda.mem_get_info(self.device)[1] / (1024**3)

    def freeze_except(self, model_names):
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)
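

# A minimal usage sketch (illustrative only; this helper is an assumption of this
# example and is not referenced elsewhere in the module). It exercises the tensor
# helpers of BasePipeline on CPU, without any model weights.
def _example_base_pipeline_round_trip():
    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
    # Round the resolution to multiples of the division factors (64 by default).
    height, width = pipe.check_resize_height_width(720, 1280)
    # PIL.Image -> tensor in [-1, 1] with layout "B C H W".
    image = Image.new("RGB", (width, height), color=(128, 128, 128))
    image_tensor = pipe.preprocess_image(image)
    # Reproducible Gaussian noise with the same shape as the preprocessed image.
    noise = pipe.generate_noise(image_tensor.shape, seed=0)
    # Tensor in [-1, 1] -> PIL.Image.
    decoded = pipe.vae_output_to_image(image_tensor)
    return decoded, noise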


# `ModelConfig` is a plain data container, so the `dataclass` decorator (imported
# above) generates its constructor.
@dataclass
class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    local_model_path: str = None
    skip_download: bool = False

    def download_if_necessary(self, use_usp=False):
        if self.path is None:
            # Check model_id and origin_file_pattern
            if self.model_id is None:
                raise ValueError(
                    f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`."""
                )
            # Skip if not in rank 0
            if use_usp:
                import torch.distributed as dist

                skip_download = self.skip_download or dist.get_rank() != 0
            else:
                skip_download = self.skip_download
            # Check whether the origin path is a folder
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.origin_file_pattern = ""
                allow_file_pattern = None
                is_folder = True
            elif isinstance(
                self.origin_file_pattern, str
            ) and self.origin_file_pattern.endswith("/"):
                allow_file_pattern = self.origin_file_pattern + "*"
                is_folder = True
            else:
                allow_file_pattern = self.origin_file_pattern
                is_folder = False
            # Download
            if not skip_download:
                if self.local_model_path is None:
                    self.local_model_path = "./models"
                downloaded_files = glob.glob(
                    self.origin_file_pattern,
                    root_dir=os.path.join(self.local_model_path, self.model_id),
                )
                snapshot_download(
                    self.model_id,
                    local_dir=os.path.join(self.local_model_path, self.model_id),
                    allow_file_pattern=allow_file_pattern,
                    ignore_file_pattern=downloaded_files,
                    local_files_only=False,
                )
            # Let rank 1, 2, ... wait for rank 0
            if use_usp:
                import torch.distributed as dist

                dist.barrier(device_ids=[dist.get_rank()])
            # Return downloaded files
            if is_folder:
                self.path = os.path.join(
                    self.local_model_path, self.model_id, self.origin_file_pattern
                )
            else:
                self.path = glob.glob(
                    os.path.join(
                        self.local_model_path, self.model_id, self.origin_file_pattern
                    )
                )
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]
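
# Illustrative usage (the model IDs and file names below are placeholders):
#   ModelConfig(path="models/some_model.safetensors")                       # use a local file directly
#   ModelConfig(model_id="org/model", origin_file_pattern="*.safetensors")  # resolve via download
# Calling `download_if_necessary()` on the second form downloads the matching files
# from ModelScope into `local_model_path` (default "./models") and fills in `path`.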


class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None,
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def process(
        self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs
    ) -> dict:
        raise NotImplementedError("`process` is not implemented.")


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(
        self,
        unit: PipelineUnit,
        pipe: BasePipeline,
        inputs_shared: dict,
        inputs_posi: dict,
        inputs_nega: dict,
    ) -> tuple[dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(
                pipe,
                inputs_shared=inputs_shared,
                inputs_posi=inputs_posi,
                inputs_nega=inputs_nega,
            )
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {
                name: inputs_posi.get(name_)
                for name, name_ in unit.input_params_posi.items()
            }
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {
                    name: inputs_nega.get(name_)
                    for name, name_ in unit.input_params_nega.items()
                }
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                # When CFG is disabled (cfg_scale == 1), reuse the positive outputs.
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {
                name: inputs_shared.get(name) for name in unit.input_params
            }
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
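

# A minimal sketch (illustrative only) of how a unit plugs into the runner above.
# The unit, function name, and inputs below are assumptions made for this example.
def _example_run_pipeline_unit():
    class _ResizeUnit(PipelineUnit):
        def __init__(self):
            super().__init__(input_params=("height", "width"))

        def process(self, pipe, height, width):
            # Round the requested resolution using the pipeline's division factors.
            height, width = pipe.check_resize_height_width(height, width)
            return {"height": height, "width": width}

    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
    runner = PipelineUnitRunner()
    inputs_shared, inputs_posi, inputs_nega = runner(
        _ResizeUnit(),
        pipe,
        inputs_shared={"height": 720, "width": 1280, "cfg_scale": 1},
        inputs_posi={},
        inputs_nega={},
    )
    return inputs_shared  # {"height": 768, "width": 1280, "cfg_scale": 1}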
