Upload 33 files
- LICENSE +21 -0
- README.md +165 -12
- assets/motivation.jpg +0 -0
- assets/the_great_wall.jpg +0 -0
- assets/user_study.jpg +0 -0
- assets/vbench.jpg +0 -0
- diffusion_schedulers/__init__.py +2 -0
- diffusion_schedulers/scheduling_cosine_ddpm.py +137 -0
- diffusion_schedulers/scheduling_flow_matching.py +298 -0
- pyramid_dit/__init__.py +3 -0
- pyramid_dit/modeling_embedding.py +390 -0
- pyramid_dit/modeling_mmdit_block.py +672 -0
- pyramid_dit/modeling_normalization.py +179 -0
- pyramid_dit/modeling_pyramid_mmdit.py +487 -0
- pyramid_dit/modeling_text_encoder.py +140 -0
- pyramid_dit/pyramid_dit_for_video_gen_pipeline.py +672 -0
- requirements.txt +32 -0
- trainer_misc/__init__.py +25 -0
- trainer_misc/communicate.py +58 -0
- trainer_misc/sp_utils.py +98 -0
- trainer_misc/utils.py +382 -0
- utils.py +457 -0
- video_generation_demo.ipynb +181 -0
- video_vae/__init__.py +2 -0
- video_vae/context_parallel_ops.py +172 -0
- video_vae/modeling_block.py +760 -0
- video_vae/modeling_causal_conv.py +139 -0
- video_vae/modeling_causal_vae.py +625 -0
- video_vae/modeling_discriminator.py +122 -0
- video_vae/modeling_enc_dec.py +422 -0
- video_vae/modeling_loss.py +192 -0
- video_vae/modeling_lpips.py +120 -0
- video_vae/modeling_resnet.py +729 -0
    	
LICENSE
ADDED

MIT License

Copyright (c) 2024 Yang Jin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
    	
README.md
CHANGED

<div align="center">

# ⚡️Pyramid Flow⚡️

[[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page ✨]](https://pyramid-flow.github.io) [[Model 🤗]](https://huggingface.co/rain1011/pyramid-flow-sd3)

</div>

This is the official repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on **open-source datasets**, it can generate high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.

<table class="center" border="0" style="width: 100%; text-align: left;">
<tr>
  <th>10s, 768p, 24fps</th>
  <th>5s, 768p, 24fps</th>
  <th>Image-to-video</th>
</tr>
<tr>
  <td><video src="https://github.com/user-attachments/assets/9935da83-ae56-4672-8747-0f46e90f7b2b" autoplay muted loop playsinline></video></td>
  <td><video src="https://github.com/user-attachments/assets/3412848b-64db-4d9e-8dbf-11403f6d02c5" autoplay muted loop playsinline></video></td>
  <td><video src="https://github.com/user-attachments/assets/3bd7251f-7b2c-4bee-951d-656fdb45f427" autoplay muted loop playsinline></video></td>
</tr>
</table>

## News

* `COMING SOON` ⚡️⚡️⚡️ Training code for both the Video VAE and DiT; new model checkpoints trained from scratch.

  > We are training Pyramid Flow from scratch to fix the human structure issues related to the currently adopted SD3 initialization, and hope to release it in the next few days.
* `2024.10.10` 🚀🚀🚀 We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.

## Introduction

![motivation](assets/motivation.jpg)

Existing video diffusion models operate at full resolution, spending a lot of computation on very noisy latents. By contrast, our method harnesses the flexibility of flow matching ([Lipman et al., 2023](https://openreview.net/forum?id=PqvMRDCJT9t); [Liu et al., 2023](https://openreview.net/forum?id=XVjTT1nw5z); [Albergo & Vanden-Eijnden, 2023](https://openreview.net/forum?id=li7qeBbCR1t)) to interpolate between latents of different resolutions and noise levels, allowing simultaneous generation and decompression of visual content with better computational efficiency. The entire framework is end-to-end optimized with a single DiT ([Peebles & Xie, 2023](http://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html)), generating high-quality 10-second videos at 768p resolution and 24 FPS within 20.7k A100 GPU training hours.
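For orientation, here is a sketch of the standard flow-matching interpolation the pyramid stages build on (our notation, not an equation copied from the paper): with $x_0$ a clean latent and $x_1$ Gaussian noise,

```math
x_\sigma = (1 - \sigma)\,x_0 + \sigma\,x_1, \qquad v_\theta(x_\sigma, \sigma) \approx x_1 - x_0,
```

so one Euler sampling step is $x_{\sigma'} = x_\sigma + (\sigma' - \sigma)\,v_\theta$, exactly the update in `PyramidFlowMatchEulerDiscreteScheduler.step` later in this commit; the pyramid stages apply it between latents at different resolutions rather than only at full resolution.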
## Usage

You can directly download the model from [Huggingface](https://huggingface.co/rain1011/pyramid-flow-sd3). We provide model checkpoints for both 768p and 384p video generation. The 384p checkpoint supports 5-second video generation at 24 FPS, while the 768p checkpoint supports up to 10-second video generation at 24 FPS.

```python
from huggingface_hub import snapshot_download

model_path = 'PATH'   # The local directory to save the downloaded checkpoint
snapshot_download("rain1011/pyramid-flow-sd3", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
```

To use our model, please follow the inference code in `video_generation_demo.ipynb` at [this link](https://github.com/jy0205/Pyramid-Flow/blob/main/video_generation_demo.ipynb). We further simplify it into the following two-step procedure. First, load the downloaded model:

```python
import torch
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import load_image, export_to_video

torch.cuda.set_device(0)
model_dtype, torch_dtype = 'bf16', torch.bfloat16   # Use bf16, fp16 or fp32

model = PyramidDiTForVideoGeneration(
    'PATH',                                         # The downloaded checkpoint dir
    model_dtype,
    model_variant='diffusion_transformer_768p',     # or 'diffusion_transformer_384p'
)

model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
model.vae.enable_tiling()
```

Then, you can try text-to-video generation on your own prompts:

```python
prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"

with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
    frames = model.generate(
        prompt=prompt,
        num_inference_steps=[20, 20, 20],
        video_num_inference_steps=[10, 10, 10],
        height=768,
        width=1280,
        temp=16,                    # temp=16: 5s, temp=31: 10s
        guidance_scale=9.0,         # The guidance for the first frame
        video_guidance_scale=5.0,   # The guidance for the other video latents
        output_type="pil",
        save_memory=True,           # If you have enough GPU memory, set it to `False` to improve VAE decoding speed
    )

export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
```

As an autoregressive model, our model also supports (text-conditioned) image-to-video generation:

```python
image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((1280, 768))
prompt = "FPV flying over the Great Wall"

with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
    frames = model.generate_i2v(
        prompt=prompt,
        input_image=image,
        num_inference_steps=[10, 10, 10],
        temp=16,
        video_guidance_scale=4.0,
        output_type="pil",
        save_memory=True,           # If you have enough GPU memory, set it to `False` to improve VAE decoding speed
    )

export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
```

Usage tips:

* The `guidance_scale` parameter controls the visual quality. We suggest a guidance scale within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
* The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates autoregressive generation degradation, while a smaller value stabilizes the video.
* For 10-second video generation, we recommend a guidance scale of 7 and a video guidance scale of 5; see the sketch right after this list.
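Putting the last tip together, a minimal 10-second sketch (ours, not a snippet from the repo) that reuses `model`, `torch_dtype` and `prompt` from the examples above:

```python
# 10-second text-to-video: temp=31 (about 10s at 24 FPS) with guidance 7 / 5 as
# recommended above; all parameter names are those of the earlier examples.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
    frames = model.generate(
        prompt=prompt,
        num_inference_steps=[20, 20, 20],
        video_num_inference_steps=[10, 10, 10],
        height=768,
        width=1280,
        temp=31,                    # 10-second generation
        guidance_scale=7.0,
        video_guidance_scale=5.0,
        output_type="pil",
        save_memory=True,
    )

export_to_video(frames, "./text_to_video_10s.mp4", fps=24)
```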
## Gallery

The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).

<table class="center" border="0" style="width: 100%; text-align: left;">
<tr>
  <td><video src="https://github.com/user-attachments/assets/5b44a57e-fa08-4554-84a2-2c7a99f2b343" autoplay muted loop playsinline></video></td>
  <td><video src="https://github.com/user-attachments/assets/5afd5970-de72-40e2-900d-a20d18308e8e" autoplay muted loop playsinline></video></td>
</tr>
<tr>
  <td><video src="https://github.com/user-attachments/assets/1d44daf8-017f-40e9-bf18-1e19c0a8983b" autoplay muted loop playsinline></video></td>
  <td><video src="https://github.com/user-attachments/assets/7f5dd901-b7d7-48cc-b67a-3c5f9e1546d2" autoplay muted loop playsinline></video></td>
</tr>
</table>

## Comparison

On VBench ([Huang et al., 2024](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard)), our method surpasses all the compared open-source baselines. Even with only public video data, it achieves performance comparable to commercial models like Kling ([Kuaishou, 2024](https://kling.kuaishou.com/en)) and Gen-3 Alpha ([Runway, 2024](https://runwayml.com/research/introducing-gen-3-alpha)), especially in the quality score (84.74 vs. 84.11 for Gen-3) and motion smoothness.

![vbench](assets/vbench.jpg)

We conducted an additional user study with 20+ participants. As can be seen, our method is preferred over open-source models such as [Open-Sora](https://github.com/hpcaitech/Open-Sora) and [CogVideoX-2B](https://github.com/THUDM/CogVideo), especially in terms of motion smoothness.

![user_study](assets/user_study.jpg)

## Acknowledgement

We are grateful to the following awesome projects, which we drew on when implementing Pyramid Flow:

* [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
* [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
* [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
* [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
* [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM used for our video recaptioning.

## Citation

Consider giving this repository a star and citing Pyramid Flow in your publications if it helps your research.

```
@article{jin2024pyramidal,
  title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
  author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
  journal={arXiv preprint arXiv:2410.05954},
  year={2024}
}
```
    	
assets/motivation.jpg
ADDED

assets/the_great_wall.jpg
ADDED

assets/user_study.jpg
ADDED

assets/vbench.jpg
ADDED
    	
diffusion_schedulers/__init__.py
ADDED

```python
from .scheduling_cosine_ddpm import DDPMCosineScheduler
from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
```
    	
diffusion_schedulers/scheduling_cosine_ddpm.py
ADDED

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin


@dataclass
class DDPMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.Tensor


class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        scaler: float = 1.0,
        s: float = 0.008,
    ):
        self.scaler = scaler
        self.s = torch.tensor([s])
        self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

    def _alpha_cumprod(self, t, device):
        if self.scaler > 1:
            t = 1 - (1 - t) ** self.scaler
        elif self.scaler < 1:
            t = t**self.scaler
        alpha_cumprod = torch.cos(
            (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
        ) ** 2 / self._init_alpha_cumprod.to(device)
        return alpha_cumprod.clamp(0.0001, 0.9999)

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.Tensor`: scaled input sample
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        timesteps: Optional[List[int]] = None,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If passed,
                `timesteps` must be `None`.
            timesteps (`List[int]`, optional):
                Custom timesteps to use; overrides `num_inference_steps` when given.
            device (`str` or `torch.device`, optional):
                The device to which the timesteps are moved.
        """
        if timesteps is None:
            timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
        if not isinstance(timesteps, torch.Tensor):
            timesteps = torch.Tensor(timesteps).to(device)
        self.timesteps = timesteps

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        dtype = model_output.dtype
        device = model_output.device
        t = timestep

        prev_t = self.previous_timestep(t)

        alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
        alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
        alpha = alpha_cumprod / alpha_cumprod_prev

        mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())

        std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
        std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
        pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])

        if not return_dict:
            return (pred.to(dtype),)

        return DDPMSchedulerOutput(prev_sample=pred.to(dtype))

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        device = original_samples.device
        dtype = original_samples.dtype
        alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
            timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
        )
        noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
        return noisy_samples.to(dtype=dtype)

    def __len__(self):
        return self.config.num_train_timesteps

    def previous_timestep(self, timestep):
        index = (self.timesteps - timestep[0]).abs().argmin().item()
        prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
        return prev_t
```
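Since `DDPMCosineScheduler` is self-contained, a smoke test is straightforward. Below is a hypothetical usage sketch (ours, not part of the commit) that drives one denoising loop on a dummy latent, with random noise standing in for a real denoiser:

```python
import torch

from diffusion_schedulers import DDPMCosineScheduler

scheduler = DDPMCosineScheduler()
scheduler.set_timesteps(num_inference_steps=20)   # continuous timesteps from 1.0 down to 0.0

latent = torch.randn(1, 4, 32, 32)                # init_noise_sigma is 1.0
for t in scheduler.timesteps[:-1]:                # the trailing 0.0 is only the target of the last step
    t_batch = t[None]                             # step() indexes timestep[0], so pass a batched tensor
    noise_pred = torch.randn_like(latent)         # stand-in for a real model call
    latent = scheduler.step(noise_pred, t_batch, latent).prev_sample
```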
    	
diffusion_schedulers/scheduling_flow_matching.py
ADDED

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import math
import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,     # Following Stable Diffusion 3
        stages: int = 3,
        stage_range: List = [0, 1/3, 2/3, 1],
        gamma: float = 1/3,
    ):
        self.timestep_ratios = {}           # The timestep ratio for each stage
        self.timesteps_per_stage = {}       # The detailed timesteps per stage
        self.sigmas_per_stage = {}
        self.start_sigmas = {}
        self.end_sigmas = {}
        self.ori_start_sigmas = {}

        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma

    def init_sigmas(self):
        """
        Initialize the global timesteps and sigmas.
        """
        num_train_timesteps = self.config.num_train_timesteps
        shift = self.config.shift

        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def init_sigmas_for_each_stage(self):
        """
        Initialize the timesteps for each stage.
        """
        self.init_sigmas()

        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range

        # Init the start and end point of each stage
        for i_s in range(stages):
            # Decide the start and end points
            start_indice = int(stage_range[i_s] * training_steps)
            start_indice = max(start_indice, 0)
            end_indice = int(stage_range[i_s+1] * training_steps)
            end_indice = min(end_indice, training_steps)
            start_sigma = self.sigmas[start_indice].item()
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma

            if i_s != 0:
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
                start_sigma = 1 - corrected_sigma

            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma

        # Determine the ratio of each stage according to flow length
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            if i_s == 0:
                start_ratio = 0.0
            else:
                start_ratio = sum(stage_distance[:i_s]) / tot_distance
            if i_s == stages - 1:
                end_ratio = 1.0
            else:
                end_ratio = sum(stage_distance[:i_s+1]) / tot_distance

            self.timestep_ratios[i_s] = (start_ratio, end_ratio)

        # Determine the timesteps and sigmas for each stage
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(
                timestep_max, timestep_min, training_steps + 1,
            )
            self.timesteps_per_stage[i_s] = torch.from_numpy(timesteps[:-1])
            stage_sigmas = np.linspace(
                1, 0, training_steps + 1,
            )
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
        """
        Set the timesteps and sigmas for the given stage.
        """
        self.num_inference_steps = num_inference_steps
        training_steps = self.config.num_train_timesteps
        self.init_sigmas()

        stage_timesteps = self.timesteps_per_stage[stage_index]
        timestep_max = stage_timesteps[0].item()
        timestep_min = stage_timesteps[-1].item()

        timesteps = np.linspace(
            timestep_max, timestep_min, num_inference_steps,
        )
        self.timesteps = torch.from_numpy(timesteps).to(device=device)

        stage_sigmas = self.sigmas_per_stage[stage_index]
        sigma_max = stage_sigmas[0].item()
        sigma_min = stage_sigmas[-1].item()

        ratios = np.linspace(
            sigma_max, sigma_min, num_inference_steps
        )
        sigmas = torch.from_numpy(ratios).to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._step_index = 0

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output
```
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # Cast sample back to model compatible dtype
         | 
| 287 | 
            +
                    prev_sample = prev_sample.to(model_output.dtype)
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    # upon completion increase step index by one
         | 
| 290 | 
            +
                    self._step_index += 1
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    if not return_dict:
         | 
| 293 | 
            +
                        return (prev_sample,)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                def __len__(self):
         | 
| 298 | 
            +
                    return self.config.num_train_timesteps
         | 
    	
        pyramid_dit/__init__.py
    ADDED
    
from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
from .modeling_text_encoder import SD3TextEncoderWithMask
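These re-exports define the package's public surface, so downstream code can import directly from `pyramid_dit`:

    from pyramid_dit import PyramidDiTForVideoGeneration, SD3TextEncoderWithMask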
    	
        pyramid_dit/modeling_embedding.py
    ADDED
    
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import numpy as np
import math

from diffusers.models.activations import get_activation
from einops import rearrange


def get_1d_sincos_pos_embed(
    embed_dim, num_frames, cls_token=False, extra_tokens=0,
):
    t = np.arange(num_frames, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t)  # (T, D)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
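As a quick sanity check on shapes (illustrative only, using the helpers defined above):

    import numpy as np

    emb = get_1d_sincos_pos_embed_from_grid(64, np.arange(16, dtype=np.float32))
    assert emb.shape == (16, 64)        # sin half then cos half along the last axis

    pe = get_2d_sincos_pos_embed(128, grid_size=8, base_size=16)
    assert pe.shape == (8 * 8, 128)     # one 128-d embedding per spatial grid position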
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
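A small illustration of the embedding produced above (toy values, matching the `Timesteps` defaults used later in this file):

    import torch

    t = torch.tensor([0.0, 250.0, 999.0])
    emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    assert emb.shape == (3, 256)        # cosine features first, sine features second, after the flip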
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.act = get_activation(act_fn)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size, act_fn="silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = get_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class CombinedTimestepConditionEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = timesteps_emb + pooled_projections
        return conditioning


class CombinedTimestepEmbeddings(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb
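Taken together, the embedders above combine a scalar timestep and a pooled text embedding into one conditioning vector; a toy wiring (all dimensions made up):

    import torch

    cond_embedder = CombinedTimestepConditionEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)
    timestep = torch.tensor([500.0, 500.0])
    pooled_text = torch.randn(2, 2048)

    cond = cond_embedder(timestep, pooled_text)
    assert cond.shape == (2, 1536)      # timestep embedding plus projected pooled-text embedding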
class PatchEmbed3D(nn.Module):
    """Supports 3D (video) tensor input."""

    def __init__(
        self,
        height=128,
        width=128,
        patch_size=2,
        in_channels=16,
        embed_dim=1536,
        layer_norm=False,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        temp_pos_embed_type='rope',
        pos_embed_max_size=192,   # For SD3 cropping
        max_num_frames=64,
        add_temp_pos_embed=False,
        interp_condition_pos=False,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.add_temp_pos_embed = add_temp_pos_embed

        # Calculate positional embeddings based on max size or default
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None

        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            persistent = True if pos_embed_max_size else False
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)

            if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
                time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
                self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)

        elif pos_embed_type == "rope":
            print("Using the rotary position embedding")

        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.interp_condition_pos = interp_condition_pos

    def cropped_pos_embed(self, height, width, ori_height, ori_width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        ori_height = ori_height // self.patch_size
        ori_width = ori_width // self.patch_size

        assert ori_height >= height, "ori_height must be >= height"
        assert ori_width >= width, "ori_width must be >= width"

        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        if self.interp_condition_pos:
            # Crop at the original (reference) resolution, then interpolate down to the current stage size
            top = (self.pos_embed_max_size - ori_height) // 2
            left = (self.pos_embed_max_size - ori_width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :]   # [b h w c]
            if ori_height != height or ori_width != width:
                spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
                spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
                spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
        else:
            top = (self.pos_embed_max_size - height) // 2
            left = (self.pos_embed_max_size - width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]

        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])

        return spatial_pos_embed

    def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
        if self.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        bs = latent.shape[0]
        temp = latent.shape[2]

        latent = rearrange(latent, 'b c t h w -> (b t) c h w')
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # (BT)CHW -> (BT)NC

        if self.layer_norm:
            latent = self.norm(latent)

        if self.pos_embed_type == 'sincos':
            # Spatial position embedding: interpolate or crop positional embeddings as needed
            if self.pos_embed_max_size:
                pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
            else:
                raise NotImplementedError("sincos pos embed without the SD3 max-pos crop is not implemented")
                if self.height != height or self.width != width:
                    pos_embed = get_2d_sincos_pos_embed(
                        embed_dim=self.pos_embed.shape[-1],
                        grid_size=(height, width),
                        base_size=self.base_size,
                        interpolation_scale=self.interpolation_scale,
                    )
                    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
                else:
                    pos_embed = self.pos_embed

            if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
                latent_dtype = latent.dtype
                latent = latent + pos_embed
                latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
                latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
                latent = latent.to(latent_dtype)
                latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
            else:
                latent = (latent + pos_embed).to(latent.dtype)
                latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        else:
            assert self.pos_embed_type == "rope", "Only the sincos and rope embeddings are supported"
            latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        return latent

    def forward(self, latent):
        """
        Arguments:
            latent (`torch.FloatTensor` or list): the video latents to patchify. A nested list carries the past
                condition latents used during generation; in that case each stage is flattened into a 1-D token
                sequence and the per-stage sequences are concatenated.
        """
        if isinstance(latent, list):
            output_list = []

            for latent_ in latent:
                if not isinstance(latent_, list):
                    latent_ = [latent_]

                output_latent = []
                time_index = 0
                ori_height, ori_width = latent_[-1].shape[-2:]
                for each_latent in latent_:
                    hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
                    time_index += each_latent.shape[2]
                    hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
                    output_latent.append(hidden_state)

                output_latent = torch.cat(output_latent, dim=1)
                output_list.append(output_latent)

            return output_list
        else:
            hidden_states = self.forward_func(latent)
            hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
            return hidden_states
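A toy invocation of `PatchEmbed3D` via its list path (all sizes are illustrative and shrunk to keep the example light; the real model uses much larger dimensions):

    import torch

    patcher = PatchEmbed3D(patch_size=2, in_channels=16, embed_dim=64, pos_embed_max_size=64)

    latent = torch.randn(1, 16, 4, 32, 32)          # (b, c, t, h, w) video latent
    tokens = patcher([latent])[0]                   # list input -> one flattened token sequence per stage
    assert tokens.shape == (1, 4 * 16 * 16, 64)     # t * (h/p) * (w/p) tokens of width embed_dim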
    	
        pyramid_dit/modeling_mmdit_block.py
    ADDED
    
from typing import Dict, Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.utils import deprecate

try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn_func = None
    flash_attn_qkvpacked_func = None
    flash_attn_varlen_func = None
    print("flash-attn is not available; please install flash attention to use the varlen attention path")

from trainer_misc import (
    is_sequence_parallel_initialized,
    get_sequence_parallel_group,
    get_sequence_parallel_world_size,
    all_to_all,
)

from .modeling_normalization import AdaLayerNormZero, AdaLayerNormContinuous, RMSNorm

            class FeedForward(nn.Module):
         | 
| 29 | 
            +
                r"""
         | 
| 30 | 
            +
                A feed-forward layer.
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                Parameters:
         | 
| 33 | 
            +
                    dim (`int`): The number of channels in the input.
         | 
| 34 | 
            +
                    dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
         | 
| 35 | 
            +
                    mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         | 
| 36 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 37 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 38 | 
            +
                    final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
         | 
| 39 | 
            +
                    bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                def __init__(
         | 
| 42 | 
            +
                    self,
         | 
| 43 | 
            +
                    dim: int,
         | 
| 44 | 
            +
                    dim_out: Optional[int] = None,
         | 
| 45 | 
            +
                    mult: int = 4,
         | 
| 46 | 
            +
                    dropout: float = 0.0,
         | 
| 47 | 
            +
                    activation_fn: str = "geglu",
         | 
| 48 | 
            +
                    final_dropout: bool = False,
         | 
| 49 | 
            +
                    inner_dim=None,
         | 
| 50 | 
            +
                    bias: bool = True,
         | 
| 51 | 
            +
                ):
         | 
| 52 | 
            +
                    super().__init__()
         | 
| 53 | 
            +
                    if inner_dim is None:
         | 
| 54 | 
            +
                        inner_dim = int(dim * mult)
         | 
| 55 | 
            +
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    if activation_fn == "gelu":
         | 
| 58 | 
            +
                        act_fn = GELU(dim, inner_dim, bias=bias)
         | 
| 59 | 
            +
                    if activation_fn == "gelu-approximate":
         | 
| 60 | 
            +
                        act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
         | 
| 61 | 
            +
                    elif activation_fn == "geglu":
         | 
| 62 | 
            +
                        act_fn = GEGLU(dim, inner_dim, bias=bias)
         | 
| 63 | 
            +
                    elif activation_fn == "geglu-approximate":
         | 
| 64 | 
            +
                        act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self.net = nn.ModuleList([])
         | 
| 67 | 
            +
                    # project in
         | 
| 68 | 
            +
                    self.net.append(act_fn)
         | 
| 69 | 
            +
                    # project dropout
         | 
| 70 | 
            +
                    self.net.append(nn.Dropout(dropout))
         | 
| 71 | 
            +
                    # project out
         | 
| 72 | 
            +
                    self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         | 
| 73 | 
            +
                    # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         | 
| 74 | 
            +
                    if final_dropout:
         | 
| 75 | 
            +
                        self.net.append(nn.Dropout(dropout))
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         | 
| 78 | 
            +
                    if len(args) > 0 or kwargs.get("scale", None) is not None:
         | 
| 79 | 
            +
                        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
         | 
| 80 | 
            +
                        deprecate("scale", "1.0.0", deprecation_message)
         | 
| 81 | 
            +
                    for module in self.net:
         | 
| 82 | 
            +
                        hidden_states = module(hidden_states)
         | 
| 83 | 
            +
                    return hidden_states
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
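A toy invocation of `FeedForward` (sizes are illustrative):

    import torch

    ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
    x = torch.randn(2, 10, 64)
    y = ff(x)
    assert y.shape == (2, 10, 64)       # GEGLU gates a 64 -> 256 projection, then projects back to 64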
            +
            class VarlenFlashSelfAttentionWithT5Mask:
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __init__(self):
         | 
| 89 | 
            +
                    pass
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def apply_rope(self, xq, xk, freqs_cis):
         | 
| 92 | 
            +
                    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
         | 
| 93 | 
            +
                    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
         | 
| 94 | 
            +
                    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
         | 
| 95 | 
            +
                    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         | 
| 96 | 
            +
                    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def __call__(
         | 
| 99 | 
            +
                        self, query, key, value, encoder_query, encoder_key, encoder_value, 
         | 
| 100 | 
            +
                        heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
         | 
| 101 | 
            +
                    ):
         | 
| 102 | 
            +
                    assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    batch_size = query.shape[0]
         | 
| 105 | 
            +
                    output_hidden = torch.zeros_like(query)
         | 
| 106 | 
            +
                    output_encoder_hidden = torch.zeros_like(encoder_query)
         | 
| 107 | 
            +
                    encoder_length = encoder_query.shape[1]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    qkv_list = []
         | 
| 110 | 
            +
                    num_stages = len(hidden_length)        
         | 
| 111 | 
            +
                
         | 
| 112 | 
            +
                    encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
         | 
| 113 | 
            +
                    qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    i_sum = 0
         | 
| 116 | 
            +
                    for i_p, length in enumerate(hidden_length):
         | 
| 117 | 
            +
                        encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
         | 
| 118 | 
            +
                        qkv_tokens = qkv[:, i_sum:i_sum+length]
         | 
| 119 | 
            +
                        concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]
         | 
| 120 | 
            +
                        
         | 
| 121 | 
            +
                        if image_rotary_emb is not None:
         | 
| 122 | 
            +
                            concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        indices = encoder_attention_mask[i_p]['indices']
         | 
| 125 | 
            +
                        qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
         | 
| 126 | 
            +
                        i_sum += length
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    token_lengths = [x_.shape[0] for x_ in qkv_list]
         | 
| 129 | 
            +
                    qkv = torch.cat(qkv_list, dim=0)
         | 
| 130 | 
            +
                    query, key, value = qkv.unbind(1)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
         | 
| 133 | 
            +
                    max_seqlen_q = cu_seqlens.max().item()
         | 
| 134 | 
            +
                    max_seqlen_k = max_seqlen_q
         | 
| 135 | 
            +
                    cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
         | 
| 136 | 
            +
                    cu_seqlens_k = cu_seqlens_q.clone()
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    output = flash_attn_varlen_func(
         | 
| 139 | 
            +
                        query,
         | 
| 140 | 
            +
                        key,
         | 
| 141 | 
            +
                        value,
         | 
| 142 | 
            +
                        cu_seqlens_q=cu_seqlens_q,
         | 
| 143 | 
            +
                        cu_seqlens_k=cu_seqlens_k,
         | 
| 144 | 
            +
                        max_seqlen_q=max_seqlen_q,
         | 
| 145 | 
            +
                        max_seqlen_k=max_seqlen_k,
         | 
| 146 | 
            +
                        dropout_p=0.0,
         | 
| 147 | 
            +
                        causal=False,
         | 
| 148 | 
            +
                        softmax_scale=scale,
         | 
| 149 | 
            +
                    )
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    # To merge the tokens
         | 
| 152 | 
            +
                    i_sum = 0;token_sum = 0
         | 
| 153 | 
            +
                    for i_p, length in enumerate(hidden_length):
         | 
| 154 | 
            +
                        tot_token_num = token_lengths[i_p]
         | 
| 155 | 
            +
                        stage_output = output[token_sum : token_sum + tot_token_num]
         | 
| 156 | 
            +
                        stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
         | 
| 157 | 
            +
                        stage_encoder_hidden_output = stage_output[:, :encoder_length]
         | 
| 158 | 
            +
                        stage_hidden_output = stage_output[:, encoder_length:]   
         | 
| 159 | 
            +
                        output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
         | 
| 160 | 
            +
                        output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
         | 
| 161 | 
            +
                        token_sum += tot_token_num
         | 
| 162 | 
            +
                        i_sum += length
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    output_hidden = output_hidden.flatten(2, 3)
         | 
| 165 | 
            +
                    output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    return output_hidden, output_encoder_hidden
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
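# Illustrative note (not part of the original file): flash_attn_varlen_func consumes the
# packed token stream above through cumulative sequence lengths. Assuming a stage whose
# per-sample valid lengths (text + visual tokens) are [3, 5, 4], the boundaries would be:
#
#     seqlens = torch.tensor([3, 5, 4], dtype=torch.int32)
#     cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
#     # -> tensor([0, 3, 8, 12], dtype=torch.int32); sample i spans cu_seqlens[i]:cu_seqlens[i + 1]

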
            class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
            self, query, key, value, encoder_query, encoder_key, encoder_value,
            heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
        ):
        assert encoder_attention_mask is not None, "The encoder-hidden mask needs to be set"

        batch_size = query.shape[0]
        qkv_list = []
        num_stages = len(hidden_length)

        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]

        # Sync the encoder query, key and values across the sequence-parallel group
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]

        output_hidden = torch.zeros_like(qkv[:, :, 0])
        output_encoder_hidden = torch.zeros_like(encoder_qkv[:, :, 0])
        encoder_length = encoder_qkv.shape[1]

        i_sum = 0
        for i_p, length in enumerate(hidden_length):
            # Get the query, key, value of this stage from the padded sequence
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum + length]
            qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, pad_seq, 3, nhead, dim]

            if image_rotary_emb is not None:
                concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1] = self.apply_rope(
                    concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1], image_rotary_emb[i_p],
                )

            indices = encoder_attention_mask[i_p]['indices']
            qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
            i_sum += length

        token_lengths = [x_.shape[0] for x_ in qkv_list]
        qkv = torch.cat(qkv_list, dim=0)
        query, key, value = qkv.unbind(1)

        cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
        max_seqlen_q = cu_seqlens.max().item()
        max_seqlen_k = max_seqlen_q
        cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
        cu_seqlens_k = cu_seqlens_q.clone()

        output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=scale,
        )

        # Merge the tokens back into per-stage padded sequences
        i_sum = 0
        token_sum = 0
        for i_p, length in enumerate(hidden_length):
            tot_token_num = token_lengths[i_p]
            stage_output = output[token_sum : token_sum + tot_token_num]
            stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
            stage_encoder_hidden_output = stage_output[:, :encoder_length]
            stage_hidden_output = stage_output[:, encoder_length:]
            stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
            output_hidden[:, i_sum:i_sum + length] = stage_hidden_output
            output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
            token_sum += tot_token_num
            i_sum += length

        output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
        output_hidden = output_hidden.flatten(2, 3)
        output_encoder_hidden = output_encoder_hidden.flatten(2, 3)

        return output_hidden, output_encoder_hidden

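# Illustrative note (not part of the original file): the all_to_all calls above trade a
# head shard for a sequence shard. Assuming sp_group_size = 2 and a local tensor of shape
# [bs, sub_seq=128, 3, head=16, head_dim], scatter_dim=3 / gather_dim=1 yields
# [bs, seq=256, 3, sub_head=8, head_dim]: every rank then attends over the full sequence
# with a subset of the heads, and the inverse call (scatter_dim=1, gather_dim=2) on the
# outputs restores the original sequence-sharded, full-head layout.

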
            class VarlenSelfAttentionWithT5Mask:

    """
        For chunk stage attention without using flash attention
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
            self, query, key, value, encoder_query, encoder_key, encoder_value,
            heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
        ):
        assert attention_mask is not None, "The attention mask needs to be set"

        encoder_length = encoder_query.shape[1]
        num_stages = len(hidden_length)

        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]

        i_sum = 0
        output_encoder_hidden_list = []
        output_hidden_list = []

        for i_p, length in enumerate(hidden_length):
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum + length]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]

            if image_rotary_emb is not None:
                concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1] = self.apply_rope(
                    concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1], image_rotary_emb[i_p],
                )

            query, key, value = concat_qkv_tokens.unbind(2)   # [bs, tot_seq, nhead, dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
            stage_hidden_states = F.scaled_dot_product_attention(
                query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
            )
            stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3)   # [bs, tot_seq, dim]

            output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
            output_hidden_list.append(stage_hidden_states[:, encoder_length:])
            i_sum += length

        output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1)  # [b n s d]
        output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
        output_hidden = torch.cat(output_hidden_list, dim=1)

        return output_hidden, output_encoder_hidden

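# Illustrative note (not part of the original file): F.scaled_dot_product_attention expects
# [bs, nhead, seq, head_dim] inputs, hence the transpose(1, 2) round trip above. Its
# attn_mask only needs to broadcast against [bs, nhead, tot_seq, tot_seq], so a boolean
# mask (True = attend) or an additive float mask of shape [bs, 1, tot_seq, tot_seq] is
# enough to mask out the padded text tokens of each stage.

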
            class SequenceParallelVarlenSelfAttentionWithT5Mask:
    """
        For chunk stage attention without using flash attention
    """

    def __init__(self):
        pass

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def __call__(
            self, query, key, value, encoder_query, encoder_key, encoder_value,
            heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
        ):
        assert attention_mask is not None, "The attention mask needs to be set"

        num_stages = len(hidden_length)

        encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)  # [bs, sub_seq, 3, head, head_dim]
        qkv = torch.stack([query, key, value], dim=2)  # [bs, sub_seq, 3, head, head_dim]

        # Sync the encoder query, key and values across the sequence-parallel group
        sp_group = get_sequence_parallel_group()
        sp_group_size = get_sequence_parallel_world_size()
        encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
        encoder_length = encoder_qkv.shape[1]

        i_sum = 0
        output_encoder_hidden_list = []
        output_hidden_list = []

        for i_p, length in enumerate(hidden_length):
            encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
            qkv_tokens = qkv[:, i_sum:i_sum + length]
            qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1)  # [bs, seq, 3, sub_head, head_dim]
            concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1)  # [bs, tot_seq, 3, nhead, dim]

            if image_rotary_emb is not None:
                concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1] = self.apply_rope(
                    concat_qkv_tokens[:, :, 0], concat_qkv_tokens[:, :, 1], image_rotary_emb[i_p],
                )

            query, key, value = concat_qkv_tokens.unbind(2)   # [bs, tot_seq, nhead, dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            stage_hidden_states = F.scaled_dot_product_attention(
                query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
            )
            stage_hidden_states = stage_hidden_states.transpose(1, 2)   # [bs, tot_seq, nhead, dim]

            output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])

            output_hidden = stage_hidden_states[:, encoder_length:]
            output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
            output_hidden_list.append(output_hidden)

            i_sum += length

        output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1)  # [b n s nhead d]
        output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
        output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
        output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
        output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)

        return output_hidden, output_encoder_hidden

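# Illustrative note (not part of the original file): the two flash-attention classes above
# consume an unpadded token stream described by indices/seqlens, while the two SDPA-based
# classes keep the padded [bs, tot_seq] layout and rely on a dense per-stage attn_mask;
# all four compute the same joint text-visual attention, and JointAttention below picks
# one of them at construction time.

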
            class JointAttention(nn.Module):

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only=None,
        use_flash_attn=True,
    ):
        """
            Fix the QKNorm following flux: normalize over the head dimension
        """
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.use_bias = bias
        self.dropout = dropout

        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only

        self.scale = dim_head**-0.5
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
        elif qk_norm == 'rms_norm':
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        else:
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm' or 'rms_norm'")

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

            if qk_norm is None:
                self.norm_add_q = None
                self.norm_add_k = None
            elif qk_norm == "layer_norm":
                self.norm_add_q = nn.LayerNorm(dim_head, eps=eps)
                self.norm_add_k = nn.LayerNorm(dim_head, eps=eps)
            elif qk_norm == 'rms_norm':
                self.norm_add_q = RMSNorm(dim_head, eps=eps)
                self.norm_add_k = RMSNorm(dim_head, eps=eps)
            else:
                raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm' or 'rms_norm'")

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        if not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

        self.use_flash_attn = use_flash_attn

        # Fall back to the SDPA path when flash-attention is not installed
        if flash_attn_func is None:
            self.use_flash_attn = False

        # print(f"Using flash-attention: {self.use_flash_attn}")
        if self.use_flash_attn:
            if is_sequence_parallel_initialized():
                self.var_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
            else:
                self.var_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
        else:
            if is_sequence_parallel_initialized():
                self.var_len_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
            else:
                self.var_len_attn = VarlenSelfAttentionWithT5Mask()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,   # [B, L, S]
        hidden_length: torch.Tensor = None,
        image_rotary_emb: torch.Tensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        # This function is only used during training
        # `sample` projections.
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(query.shape[0], -1, self.heads, head_dim)
        key = key.view(key.shape[0], -1, self.heads, head_dim)
        value = value.view(value.shape[0], -1, self.heads, head_dim)

        if self.norm_q is not None:
            query = self.norm_q(query)

        if self.norm_k is not None:
            key = self.norm_k(key)

        # `context` projections.
        encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            encoder_hidden_states_query_proj.shape[0], -1, self.heads, head_dim
        )
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            encoder_hidden_states_key_proj.shape[0], -1, self.heads, head_dim
        )
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            encoder_hidden_states_value_proj.shape[0], -1, self.heads, head_dim
        )

        if self.norm_add_q is not None:
            encoder_hidden_states_query_proj = self.norm_add_q(encoder_hidden_states_query_proj)

        if self.norm_add_k is not None:
            encoder_hidden_states_key_proj = self.norm_add_k(encoder_hidden_states_key_proj)

        # Concatenate the hidden and encoder-hidden tokens, perform the attention computation, then split
        if self.use_flash_attn:
            hidden_states, encoder_hidden_states = self.var_flash_attn(
                query, key, value,
                encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
                image_rotary_emb, encoder_attention_mask,
            )
        else:
            hidden_states, encoder_hidden_states = self.var_len_attn(
                query, key, value,
                encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
                image_rotary_emb, attention_mask,
            )

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        if not self.context_pre_only:
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states

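# Illustrative sketch (not part of the original file; the numbers are hypothetical):
# JointTransformerBlock below instantiates this module roughly as
#
#     attn = JointAttention(
#         query_dim=1536, added_kv_proj_dim=1536,   # hidden size of the block
#         dim_head=64, heads=24, out_dim=1536,      # 24 heads x 64 dims = 1536
#         qk_norm="rms_norm", context_pre_only=False,
#         bias=True, use_flash_attn=True,
#     )
#
# so text (encoder) tokens get their own q/k/v projections via add_*_proj while sharing
# one joint attention with the visual tokens.

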
            class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(
        self, dim, num_attention_heads, attention_head_dim, qk_norm=None,
        context_pre_only=False, use_flash_attn=True,
    ):
        super().__init__()

        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        self.attn = JointAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim // num_attention_heads,
            heads=num_attention_heads,
            out_dim=attention_head_dim,
            qk_norm=qk_norm,
            context_pre_only=context_pre_only,
            bias=True,
            use_flash_attn=use_flash_attn,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: torch.FloatTensor, temb: torch.FloatTensor,
        attention_mask: torch.FloatTensor = None, hidden_length: List = None,
        image_rotary_emb: torch.FloatTensor = None,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb,
            )

        # Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask, attention_mask=attention_mask,
            hidden_length=hidden_length, image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

            context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states

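# Illustrative note (not part of the original file): each sub-layer above follows the
# adaLN-Zero pattern from DiT, where scale/shift/gate are projected from the conditioning
# embedding temb:
#
#     out = x + gate * f(norm(x) * (1 + scale) + shift)
#
# with f the attention or feed-forward sub-layer; the gate is typically initialized near
# zero (the "Zero" in adaLN-Zero) so each block starts close to the identity.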
    	
pyramid_dit/modeling_normalization.py
ADDED
@@ -0,0 +1,179 @@
import numbers
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils import is_torch_version


if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight

        hidden_states = hidden_states.to(input_dtype)

        return hidden_states

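# Illustrative note (not part of the original file): RMSNorm above computes, per feature
# vector x of dimension d,
#
#     y = x / sqrt(mean(x_i^2) + eps) * weight
#
# i.e. LayerNorm without mean-centering and without a bias term; the statistics are taken
# in float32 for stability and the result is cast back to the input dtype.

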
class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit odd that the norm layer can be configured to have scale and shift parameters,
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and you should most likely
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        assert hidden_length is not None

        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        batch_emb = torch.zeros_like(x).repeat(1, 1, 2)

        i_sum = 0
        num_stages = len(hidden_length)
        for i_p, length in enumerate(hidden_length):
            batch_emb[:, i_sum:i_sum + length] = emb[i_p::num_stages][:, None]
            i_sum += length

        batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
        x = self.norm(x) * (1 + batch_scale) + batch_shift
        return x

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        if hidden_length is not None:
            return self.forward_with_pad(x, conditioning_embedding, hidden_length)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x

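A sketch of what `forward_with_pad` is doing (my reading, with assumed toy sizes): the padded batch interleaves the pyramid stages along the batch axis, so `emb[i_p::num_stages]` picks stage i_p's conditioning rows and broadcasts one (scale, shift) pair across that stage's token span:

import torch

# Assumed toy sizes: 2 real samples, 3 pyramid stages interleaved along
# the batch axis, so row i of `emb` belongs to stage i % num_stages.
num_stages, bsz, dim = 3, 2, 4
hidden_length = [5, 7, 9]                         # tokens per stage
x = torch.zeros(bsz, sum(hidden_length), dim)     # merged token sequence
emb = torch.randn(num_stages * bsz, 2 * dim)      # per-(sample, stage) scale/shift

batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
i_sum = 0
for i_p, length in enumerate(hidden_length):
    batch_emb[:, i_sum:i_sum + length] = emb[i_p::num_stages][:, None]
    i_sum += length
batch_scale, batch_shift = batch_emb.chunk(2, dim=2)
print(batch_scale.shape, batch_shift.shape)  # torch.Size([2, 21, 4]) twice
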
class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
        super().__init__()
        self.emb = None
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward_with_pad(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: [bs, seq_len, dim]
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)

        emb = self.linear(self.silu(emb))
        batch_emb = torch.zeros_like(x).repeat(1, 1, 6)

        i_sum = 0
        num_stages = len(hidden_length)
        for i_p, length in enumerate(hidden_length):
            batch_emb[:, i_sum:i_sum + length] = emb[i_p::num_stages][:, None]
            i_sum += length

        batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
        x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
        return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if hidden_length is not None:
            return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
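For reference, a self-contained sketch of the adaLN-Zero idea (my example, not repo code): six modulation tensors are predicted from the conditioning embedding, and because the projection is zero-initialized (see `initialize_weights` below), scale, shift, and gates start at zero, so each block initially acts as an identity:

import torch
import torch.nn as nn

dim = 8
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
linear = nn.Linear(dim, 6 * dim)
nn.init.zeros_(linear.weight)  # adaLN-Zero init
nn.init.zeros_(linear.bias)

x, cond = torch.randn(2, 5, dim), torch.randn(2, dim)
shift, scale, gate, *_ = linear(nn.functional.silu(cond)).chunk(6, dim=1)
out = norm(x) * (1 + scale[:, None]) + shift[:, None]
print(torch.allclose(out, norm(x)), gate.abs().max().item())  # True 0.0
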
    	
pyramid_dit/modeling_pyramid_mmdit.py
ADDED
@@ -0,0 +1,487 @@
import torch
import torch.nn as nn
import os
import torch.nn.functional as F

from einops import rearrange
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm import tqdm

from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
from .modeling_normalization import AdaLayerNormContinuous
from .modeling_mmdit_block import JointTransformerBlock

from trainer_misc import (
    is_sequence_parallel_initialized,
    get_sequence_parallel_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    all_to_all,
)

from IPython import embed


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()

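A quick sanity check (my own sketch, assuming only the `rope` function above): the returned tensor stacks [[cos, -sin], [sin, cos]] blocks, one 2x2 rotation matrix per position and per channel pair, which is how the rotary embedding is later applied to query/key channel pairs:

import torch

pos = torch.arange(4, dtype=torch.float64)[None]    # [1, 4] positions
freqs = rope(pos, dim=8, theta=10000)               # [1, 4, 4, 2, 2]
x = torch.randn(1, 4, 4, 2)                         # dim 8 viewed as 4 channel pairs
x_rot = torch.einsum('bndij,bndj->bndi', freqs, x)  # rotate each pair
print(freqs.shape, x_rot.shape)
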
class EmbedNDRoPE(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(2)

class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        pos_embed_max_size: int = 192,
        max_num_frames: int = 200,
        qk_norm: str = 'rms_norm',
        pos_embed_type: str = 'rope',
        temp_pos_embed_type: str = 'sincos',
        joint_attention_dim: int = 4096,
        use_gradient_checkpointing: bool = False,
        use_flash_attn: bool = True,
        use_temporal_causal: bool = False,
        use_t5_mask: bool = False,
        add_temp_pos_embed: bool = False,
        interp_condition_pos: bool = False,
    ):
        super().__init__()

        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        assert temp_pos_embed_type in ['rope', 'sincos']

        # The input latent embedder; the name pos_embed is kept for consistency with SD3
        self.pos_embed = PatchEmbed3D(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
            max_num_frames=max_num_frames,
            pos_embed_type=pos_embed_type,
            temp_pos_embed_type=temp_pos_embed_type,
            add_temp_pos_embed=add_temp_pos_embed,
            interp_condition_pos=interp_condition_pos,
        )

        # The RoPE embedding
        if pos_embed_type == 'rope':
            self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
        else:
            self.rope_embed = None

        if temp_pos_embed_type == 'rope':
            self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
        else:
            self.temp_rope_embed = None

        self.time_text_embed = CombinedTimestepConditionEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    qk_norm=qk_norm,
                    context_pre_only=i == num_layers - 1,
                    use_flash_attn=use_flash_attn,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = use_gradient_checkpointing
        self.patch_size = patch_size
        self.use_flash_attn = use_flash_attn
        self.use_temporal_causal = use_temporal_causal
        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.add_temp_pos_embed = add_temp_pos_embed

        if self.use_temporal_causal:
            print("Using temporal causal attention")
            assert self.use_flash_attn is False, "Flash attention does not support temporal causal masking"

        if interp_condition_pos:
            print("We interpolate the position embedding of the condition latents")

        # init weights
        self.initialize_weights()

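For orientation, a sketch with assumed shapes (my example; the real per-stage shapes are chosen by the generation pipeline, not by this module): with the defaults above, inner_dim = 24 * 64 = 1536, and `forward` consumes `sample` as a list of per-stage video latents ordered from low to high resolution:

import torch

# Hypothetical 3-stage pyramid with 2x spatial steps; each entry is a
# [batch, channels, frames, height, width] latent tensor.
b, c, t = 2, 16, 8
sample = [
    torch.randn(b, c, t, 16, 16),  # lowest resolution
    torch.randn(b, c, t, 32, 32),
    torch.randn(b, c, t, 64, 64),  # full resolution
]
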
    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.pos_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.pos_embed.proj.bias, 0)

        # Initialize all the conditioning layers with a normal init
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.context_embedder.weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.transformer_blocks:
            nn.init.constant_(block.norm1.linear.weight, 0)
            nn.init.constant_(block.norm1.linear.bias, 0)
            nn.init.constant_(block.norm1_context.linear.weight, 0)
            nn.init.constant_(block.norm1_context.linear.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

    @torch.no_grad()
    def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
        latent_image_ids = torch.zeros(temp, height, width, 3)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
        return latent_image_ids.to(device=device)

    @torch.no_grad()
    def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
        base_width = width_list[-1]
        base_height = height_list[-1]
        assert base_width == max(width_list)
        assert base_height == max(height_list)

        image_ids_list = []
        for temp, height, width in zip(temp_list, height_list, width_list):
            latent_image_ids = torch.zeros(temp, height, width, 3)

            if height != base_height:
                height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
            else:
                height_pos = torch.arange(base_height).float()
            if width != base_width:
                width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
            else:
                width_pos = torch.arange(base_width).float()

            latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
            latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
            latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
            image_ids_list.append(latent_image_ids)

        return image_ids_list

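A small illustration of the coordinate interpolation above (my sketch, not repo code): a lower-resolution stage reuses the base stage's coordinate range, so RoPE phases stay aligned across pyramid levels; fractional positions are fine for rotary embeddings:

import torch
import torch.nn.functional as F

# A 16-wide stage spread over the 32-wide base coordinate range.
base_width, width = 32, 16
width_pos = F.interpolate(
    torch.arange(base_width)[None, None, :].float(), width, mode='linear'
).squeeze(0, 1)
print(width_pos[:4])  # tensor([0.5000, 2.5000, 4.5000, 6.5000])
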
    @torch.no_grad()
    def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
        latent_image_ids = torch.zeros(temp, height, width, 1)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
        return latent_image_ids.to(device=device)

    @torch.no_grad()
    def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
        image_ids_list = []

        for i_b, sample_ in enumerate(sample):
            if not isinstance(sample_, list):
                sample_ = [sample_]

            cur_image_ids = []
            start_time_stamp = 0

            for clip_ in sample_:
                _, _, temp, height, width = clip_.shape
                height = height // self.patch_size
                width = width // self.patch_size
                cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
                start_time_stamp += temp

            cur_image_ids = torch.cat(cur_image_ids, dim=1)
            image_ids_list.append(cur_image_ids)

        return image_ids_list

    def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
        """
        Merge input video latents at different resolutions into one sequence.
        `sample` is ordered from low resolution to high resolution.
        """
        if isinstance(sample[0], list):
            device = sample[0][-1].device
            pad_batch_size = sample[0][-1].shape[0]
        else:
            device = sample[0].device
            pad_batch_size = sample[0].shape[0]

        num_stages = len(sample)
        height_list = []
        width_list = []
        temp_list = []
        trainable_token_list = []

        for i_b, sample_ in enumerate(sample):
            if isinstance(sample_, list):
                sample_ = sample_[-1]
            _, _, temp, height, width = sample_.shape
            height = height // self.patch_size
            width = width // self.patch_size
            temp_list.append(temp)
            height_list.append(height)
            width_list.append(width)
            trainable_token_list.append(height * width * temp)

        # prepare the RoPE embedding if needed
        if self.pos_embed_type == 'rope':
            # TODO: support 3D RoPE for video; the code below is kept for that future path
            raise NotImplementedError("Not compatible with video generation for now")
            text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
            image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
            input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
            image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list]  # [bs, seq_len, 1, head_dim // 2, 2, 2]
        else:
            if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
                image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
                text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
                input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
                image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list]  # [bs, seq_len, 1, head_dim // 2, 2, 2]

                if is_sequence_parallel_initialized():
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for x_ in image_rotary_emb]
                    input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for input_ids in input_ids_list]

            else:
                image_rotary_emb = None

        hidden_states = self.pos_embed(sample)  # hidden_states is a list of per-stage token sequences, with batch = real_b // num_stages
        hidden_length = []

        for i_b in range(num_stages):
            hidden_length.append(hidden_states[i_b].shape[1])

        # prepare the attention mask
        if self.use_flash_attn:
            attention_mask = None
            indices_list = []
            for i_p, length in enumerate(hidden_length):
                pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
                pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)

                if is_sequence_parallel_initialized():
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
                    pad_attention_mask = pad_attention_mask.squeeze(2)

                seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
                indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()

                indices_list.append(
                    {
                        'indices': indices,
                        'seqlens_in_batch': seqlens_in_batch,
                    }
                )
            encoder_attention_mask = indices_list
        else:
            assert encoder_attention_mask.shape[1] == encoder_hidden_length
            real_batch_size = encoder_attention_mask.shape[0]
            # prepare text ids
            text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
            text_ids = text_ids.to(device)
            text_ids[encoder_attention_mask == 0] = 0

            # prepare image ids
            image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
            image_ids = image_ids.to(device)
            image_ids_list = []
            for i_p, length in enumerate(hidden_length):
                image_ids_list.append(image_ids[i_p::num_stages][:, :length])

            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2)
                image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2) for image_ids_ in image_ids_list]

            attention_mask = []
            for i_p in range(len(hidden_length)):
                image_ids = image_ids_list[i_p]
                token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
                stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j')  # [bs, 1, q_len, k_len]
                if self.use_temporal_causal:
                    input_order_ids = input_ids_list[i_p].squeeze(2)
                    temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
                    stage_attention_mask = stage_attention_mask & temporal_causal_mask
                attention_mask.append(stage_attention_mask)

        return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb

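A minimal sketch (not repo code) of the unpadding bookkeeping computed above for flash attention: the per-stage mask is reduced to the flat positions of valid tokens plus per-row valid lengths, the format that varlen attention kernels consume:

import torch

mask = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.int32)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)              # tensor([2, 1])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 3])
print(seqlens_in_batch, indices)
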
    def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
        # Split the merged hidden states back into per-stage outputs
        batch_size = batch_hidden_states.shape[0]
        output_hidden_list = []
        batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)

        if is_sequence_parallel_initialized():
            sp_group_size = get_sequence_parallel_world_size()
            batch_size = batch_size // sp_group_size

        for i_p, length in enumerate(hidden_length):
            width, height, temp = widths[i_p], heights[i_p], temps[i_p]
            trainable_token_num = trainable_token_list[i_p]
            hidden_states = batch_hidden_states[i_p]

            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

            # only the trainable tokens take part in the loss computation
            hidden_states = hidden_states[:, -trainable_token_num:]

            # unpatchify
            hidden_states = hidden_states.reshape(
                shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
            hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
            output_hidden_list.append(hidden_states)

        return output_hidden_list

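A sketch of the unpatchify step above with illustrative shapes (my example; p = patch_size). Each token carries p * p * c output pixels, which are folded back into the spatial grid:

import torch
from einops import rearrange

b, t, h, w, p, c = 1, 2, 4, 4, 2, 16
tokens = torch.randn(b, t * h * w, p * p * c)
x = tokens.reshape(b, t, h, w, p, p, c)
x = rearrange(x, "b t h w p1 p2 c -> b c t (h p1) (w p2)")
print(x.shape)  # torch.Size([1, 16, 2, 8, 8])
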
                def forward(
         | 
| 412 | 
            +
                    self,
         | 
| 413 | 
            +
                    sample: torch.FloatTensor, # [num_stages]
         | 
| 414 | 
            +
                    encoder_hidden_states: torch.FloatTensor = None,
         | 
| 415 | 
            +
                    encoder_attention_mask: torch.FloatTensor = None,
         | 
| 416 | 
            +
                    pooled_projections: torch.FloatTensor = None,
         | 
| 417 | 
            +
                    timestep_ratio: torch.FloatTensor = None,
         | 
| 418 | 
            +
                ):
         | 
| 419 | 
            +
                    # Get the timestep embedding
         | 
| 420 | 
            +
                    temb = self.time_text_embed(timestep_ratio, pooled_projections)
         | 
| 421 | 
            +
                    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         | 
| 422 | 
            +
                    encoder_hidden_length = encoder_hidden_states.shape[1]
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    # Get the input sequence
         | 
| 425 | 
            +
                    hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
         | 
| 426 | 
            +
                            attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
         | 
| 427 | 
            +
                    
         | 
| 428 | 
            +
                    # split the long latents if necessary
         | 
| 429 | 
            +
                    if is_sequence_parallel_initialized():
         | 
| 430 | 
            +
                        sp_group = get_sequence_parallel_group()
         | 
| 431 | 
            +
                        sp_group_size = get_sequence_parallel_world_size()
         | 
| 432 | 
            +
                        
         | 
| 433 | 
            +
                        # sync the input hidden states
         | 
| 434 | 
            +
                        batch_hidden_states = []
         | 
| 435 | 
            +
                        for i_p, hidden_states_ in enumerate(hidden_states):
         | 
| 436 | 
            +
                            assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
         | 
| 437 | 
            +
                            hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
         | 
| 438 | 
            +
                            hidden_length[i_p] = hidden_length[i_p] // sp_group_size
         | 
| 439 | 
            +
                            batch_hidden_states.append(hidden_states_)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                        # sync the encoder hidden states
         | 
| 442 | 
            +
                        hidden_states = torch.cat(batch_hidden_states, dim=1)
         | 
| 443 | 
            +
            encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
            temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
            temb = temb.squeeze(1)
        else:
            hidden_states = torch.cat(hidden_states, dim=1)

        # print(hidden_length)
        for i_b, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing and (i_b >= 2):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    temb,
                    attention_mask,
                    hidden_length,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                    attention_mask=attention_mask,
                    hidden_length=hidden_length,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
        hidden_states = self.proj_out(hidden_states)

        output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)

        return output
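
A note on the `i_b >= 2` branch above: blocks past the second are wrapped in activation checkpointing, which drops their intermediate activations during the forward pass and recomputes them during backward. A minimal sketch of the same pattern, using a hypothetical toy block that is not part of this repo:

import torch
import torch.nn as nn

# Hypothetical stand-in for a transformer block.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(2, 64, requires_grad=True)

# Forward under checkpointing: activations inside `block` are not stored...
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()  # ...they are recomputed here, trading compute for memory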
    	
pyramid_dit/modeling_text_encoder.py
ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import os

from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from typing import Any, Callable, Dict, List, Optional, Union


class SD3TextEncoderWithMask(nn.Module):
    def __init__(self, model_path, torch_dtype):
        super().__init__()
        # CLIP-L
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
        self.tokenizer_max_length = self.tokenizer.model_max_length
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)

        # CLIP-G
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)

        # T5
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
        self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype)

        self._freeze()

    def _freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 128,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)
        prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return pooled_prompt_embeds

    def encode_prompt(self,
        prompt,
        num_images_per_prompt=1,
        clip_skip: Optional[int] = None,
        device=None,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt

        pooled_prompt_embed = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            clip_model_index=0,
        )
        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            clip_model_index=1,
        )
        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
        )
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts, device):
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)

        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
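
A usage sketch for the encoder above, assuming a checkpoint directory laid out with the SD3-style subfolders it loads (tokenizer/, text_encoder/, ..., tokenizer_3/, text_encoder_3/); the path, prompt, and availability of a CUDA device are placeholders:

import torch

encoder = SD3TextEncoderWithMask("/path/to/checkpoint", torch_dtype=torch.bfloat16).to("cuda")

prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(
    ["a cat surfing a wave at sunset"], device="cuda"
)
# prompt_embeds: per-token T5 features [1, 128, d_t5]
# prompt_attention_mask: [1, 128] padding mask, passed through to the DiT
# pooled_prompt_embeds: concatenated CLIP-L and CLIP-G pooled projections [1, d_l + d_g]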
    	
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
ADDED
@@ -0,0 +1,672 @@
import torch
import os
import sys
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from einops import rearrange
from diffusers.utils.torch_utils import randn_tensor
import numpy as np
import math
import random
import PIL
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
from accelerate import Accelerator
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
from video_vae.modeling_causal_vae import CausalVideoVAE

from trainer_misc import (
    all_to_all,
    is_sequence_parallel_initialized,
    get_sequence_parallel_group,
    get_sequence_parallel_group_rank,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_rank,
)

from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
from .modeling_text_encoder import SD3TextEncoderWithMask


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u
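
A quick illustration of compute_density_for_timestep_sampling with the logit-normal scheme (the values are random; the shape and range are the point):

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0
)
# u: CPU tensor of 4 values in (0, 1); taking the sigmoid of N(0, 1) samples
# concentrates mass at intermediate noise levels, per Section 3.1 of the SD3 paper.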
class PyramidDiTForVideoGeneration:
    """
        The pyramid DiT for both image and video generation; the running class wrapper.
        This class is mainly for the fixed-unit implementation: 1 + n + n + n
    """
    def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=False, return_log=True,
        model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
        sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
        load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
        corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], **kwargs,
    ):
        super().__init__()

        if model_dtype == 'bf16':
            torch_dtype = torch.bfloat16
        elif model_dtype == 'fp16':
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        self.stages = stages
        self.sample_ratios = sample_ratios
        self.corrupt_ratio = corrupt_ratio

        dit_path = os.path.join(model_path, model_variant)

        # The dit
        if use_mixed_training:
            print("using mixed precision training, not explicitly casting models")
            self.dit = PyramidDiffusionMMDiT.from_pretrained(
                dit_path, use_gradient_checkpointing=use_gradient_checkpointing,
                use_flash_attn=use_flash_attn, use_t5_mask=True,
                add_temp_pos_embed=True, temp_pos_embed_type='rope',
                use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
            )
        else:
            print("using half precision")
            self.dit = PyramidDiffusionMMDiT.from_pretrained(
                dit_path, torch_dtype=torch_dtype,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_flash_attn=use_flash_attn, use_t5_mask=True,
                add_temp_pos_embed=True, temp_pos_embed_type='rope',
                use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
            )

        # The text encoder
        if load_text_encoder:
            self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
        else:
            self.text_encoder = None

        # The base video vae decoder
        if load_vae:
            self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
            # Freeze vae
            for parameter in self.vae.parameters():
                parameter.requires_grad = False
        else:
            self.vae = None

        # For the image latent
        self.vae_shift_factor = 0.1490
        self.vae_scale_factor = 1 / 1.8415

        # For the video latent
        self.vae_video_shift_factor = -0.2343
        self.vae_video_scale_factor = 1 / 3.0986

        self.downsample = 8

        # Configure the video training hyper-parameters
        # The video sequence: one frame + N * unit
        self.frame_per_unit = frame_per_unit
        self.max_temporal_length = max_temporal_length
        assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number minus one should be divisible by the number of frames per unit"
        self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))

        self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
            shift=timestep_shift, stages=len(self.stages),
            stage_range=stage_range, gamma=scheduler_gamma,
        )
        print(f"The start and end sigmas of each stage are Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")

        self.cfg_rate = 0.1
        self.return_log = return_log
        self.use_flash_attn = use_flash_attn
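
For the defaults above (max_temporal_length=31, frame_per_unit=1, sample_ratios=[1, 1, 1]), num_units_per_video works out to 1 + 30 + 3 = 34. A hypothetical construction sketch; the model path is a placeholder and must contain the model_variant subfolder:

runner = PyramidDiTForVideoGeneration(
    model_path="/path/to/pyramid-flow",
    model_dtype="bf16",
    model_variant="diffusion_transformer_768p",
    stages=[1, 2, 4],   # spatial pyramid: 1/4, 1/2, full resolution
)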
    def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        dit_checkpoint = OrderedDict()
        for key in checkpoint:
            if key.startswith('vae') or key.startswith('text_encoder'):
                continue
            if key.startswith('dit'):
                new_key = key.split('.')
                new_key = '.'.join(new_key[1:])
                dit_checkpoint[new_key] = checkpoint[key]
            else:
                dit_checkpoint[key] = checkpoint[key]

        load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
        print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")

    def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
        checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
        checkpoint = checkpoint[model_key]
        loaded_checkpoint = OrderedDict()

        for key in checkpoint.keys():
            if key.startswith('vae.'):
                new_key = key.split('.')
                new_key = '.'.join(new_key[1:])
                loaded_checkpoint[new_key] = checkpoint[key]

        load_result = self.vae.load_state_dict(loaded_checkpoint)
        print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
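
The key handling in load_checkpoint amounts to stripping the `dit.` prefix and dropping VAE/text-encoder entries, e.g. (illustrative key names, not taken from an actual checkpoint):

# 'dit.pos_embed.proj.weight'  -> 'pos_embed.proj.weight'   (prefix stripped, loaded)
# 'vae.encoder.conv_in.weight' -> skipped
# 'text_encoder.text_model...' -> skipped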
    @torch.no_grad()
    def get_pyramid_latent(self, x, stage_num):
        # x is the original vae latent
        vae_latent_list = []
        vae_latent_list.append(x)

        temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
        for _ in range(stage_num):
            height //= 2
            width //= 2
            x = rearrange(x, 'b c t h w -> (b t) c h w')
            x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
            x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
            vae_latent_list.append(x)

        vae_latent_list = list(reversed(vae_latent_list))
        return vae_latent_list
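
Shape walk-through for get_pyramid_latent (channel and frame counts are illustrative):

# x: [B, C, T, 64, 64], stage_num = 2
# returned list, coarse to fine:
#   [B, C, T, 16, 16]   # downsampled twice
#   [B, C, T, 32, 32]   # downsampled once
#   [B, C, T, 64, 64]   # the original latent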
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        temp,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            int(temp),
            int(height) // self.downsample,
            int(width) // self.downsample,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def sample_block_noise(self, bs, ch, temp, height, width):
        gamma = self.scheduler.config.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
        block_number = bs * ch * temp * (height // 2) * (width // 2)
        noise = torch.stack([dist.sample() for _ in range(block_number)])  # [block number, 4]
        noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)', b=bs, c=ch, t=temp, h=height//2, w=width//2, p=2, q=2)
        return noise
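
Why this particular covariance: each 2x2 block is drawn from N(0, (1+γ)I − γ·11ᵀ), so every entry has unit variance while entries within a block have covariance −γ. The variance of the block sum is then 4 − 12γ, which vanishes at γ = 1/3: the noise averages out under 2x2 downsampling and perturbs only the high frequencies. An empirical sanity check (using a strictly smaller γ so the covariance stays positive definite):

import torch

g = 0.2
dist = torch.distributions.MultivariateNormal(
    torch.zeros(4), torch.eye(4) * (1 + g) - torch.ones(4, 4) * g
)
s = dist.sample((200000,))
print(s.var(dim=0))        # ~1.0 per entry
print(s.sum(dim=1).var())  # ~4 - 12 * g = 1.6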
    @torch.no_grad()
    def generate_one_unit(
        self,
        latents,
        past_conditions,  # List of past conditions, contains the conditions of each stage
        prompt_embeds,
        prompt_attention_mask,
        pooled_prompt_embeds,
        num_inference_steps,
        height,
        width,
        temp,
        device,
        dtype,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        is_first_frame: bool = False,
    ):
        stages = self.stages
        intermed_latents = []

        for i_s in range(len(stages)):
            self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
            timesteps = self.scheduler.timesteps

            if i_s > 0:
                height *= 2; width *= 2
                latents = rearrange(latents, 'b c t h w -> (b t) c h w')
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
                # Fix the stage
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]   # the original coeff of signal
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                bs, ch, temp, height, width = latents.shape
                noise = self.sample_block_noise(bs, ch, temp, height, width)
                noise = noise.to(device=device, dtype=dtype)
                latents = alpha * latents + beta * noise    # To fix the block artifact

            for idx, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                latent_model_input = past_conditions[i_s] + [latent_model_input]

                noise_pred = self.dit(
                    sample=[latent_model_input],
                    timestep_ratio=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                )

                noise_pred = noise_pred[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    if is_first_frame:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

            intermed_latents.append(latents)

        return intermed_latents
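
Numerical feel for the stage-transition correction above: the upsampled latent is shrunk by alpha and renoised with beta-scaled block noise so its signal/noise mix matches what the next stage expects. With γ = 1/3 and an illustrative signal coefficient (the real value comes from the scheduler config):

import math

gamma = 1 / 3
ori_sigma = 0.67   # illustrative signal coefficient at the stage boundary
alpha = 1 / (math.sqrt(1 + 1 / gamma) * (1 - ori_sigma) + ori_sigma)
beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
print(round(alpha, 4), round(beta, 4))  # ~0.7519, ~0.4298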
| 295 | 
            +
                @torch.no_grad()
         | 
| 296 | 
            +
                def generate_i2v(
         | 
| 297 | 
            +
                    self,
         | 
| 298 | 
            +
                    prompt: Union[str, List[str]] = '',
         | 
| 299 | 
            +
                    input_image: PIL.Image = None,
         | 
| 300 | 
            +
                    temp: int = 1,
         | 
| 301 | 
            +
                    num_inference_steps: Optional[Union[int, List[int]]] = 28,
         | 
| 302 | 
            +
                    guidance_scale: float = 7.0,
         | 
| 303 | 
            +
                    video_guidance_scale: float = 4.0,
         | 
| 304 | 
            +
                    min_guidance_scale: float = 2.0,
         | 
| 305 | 
            +
                    use_linear_guidance: bool = False,
         | 
| 306 | 
            +
                    alpha: float = 0.5,
         | 
| 307 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
         | 
| 308 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 309 | 
            +
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         | 
| 310 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 311 | 
            +
                    save_memory: bool = True,
         | 
| 312 | 
            +
                ):
         | 
| 313 | 
            +
                    device = self.device
         | 
| 314 | 
            +
                    dtype = self.dtype
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    width = input_image.width
         | 
| 317 | 
            +
                    height = input_image.height
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    if isinstance(prompt, str):
         | 
| 322 | 
            +
                        batch_size = 1
         | 
| 323 | 
            +
                        prompt = prompt + ", hyper quality, Ultra HD, 8K"   # adding this prompt to improve aesthetics
         | 
| 324 | 
            +
                    else:
         | 
| 325 | 
            +
                        assert isinstance(prompt, list)
         | 
| 326 | 
            +
                        batch_size = len(prompt)
         | 
| 327 | 
            +
                        prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    if isinstance(num_inference_steps, int):
         | 
| 330 | 
            +
                        num_inference_steps = [num_inference_steps] * len(self.stages)
         | 
| 331 | 
            +
                    
         | 
| 332 | 
            +
                    negative_prompt = negative_prompt or ""
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    # Get the text embeddings
         | 
| 335 | 
            +
                    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
         | 
| 336 | 
            +
                    negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    if use_linear_guidance:
         | 
| 339 | 
            +
                        max_guidance_scale = guidance_scale
         | 
| 340 | 
            +
                        guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
         | 
| 341 | 
            +
                        print(guidance_scale_list)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    self._guidance_scale = guidance_scale
         | 
| 344 | 
            +
                    self._video_guidance_scale = video_guidance_scale
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    if self.do_classifier_free_guidance:
         | 
| 347 | 
            +
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         | 
| 348 | 
            +
                        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
         | 
| 349 | 
            +
                        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                    # Create the initial random noise
         | 
| 352 | 
            +
                    num_channels_latents = self.dit.config.in_channels
         | 
| 353 | 
            +
                    latents = self.prepare_latents(
         | 
| 354 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 355 | 
            +
                        num_channels_latents,
         | 
| 356 | 
            +
                        temp,
         | 
| 357 | 
            +
                        height,
         | 
| 358 | 
            +
                        width,
         | 
| 359 | 
            +
                        prompt_embeds.dtype,
         | 
| 360 | 
            +
                        device,
         | 
| 361 | 
            +
                        generator,
         | 
| 362 | 
            +
                    )
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                    temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    latents = rearrange(latents, 'b c t h w -> (b t) c h w')
         | 
| 367 | 
            +
                    # by defalut, we needs to start from the block noise
         | 
| 368 | 
            +
                    for _ in range(len(self.stages)-1):
         | 
| 369 | 
            +
                        height //= 2;width //= 2
         | 
| 370 | 
            +
                        latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
         | 
| 371 | 
            +
                    
         | 
| 372 | 
            +
                    latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    num_units = temp // self.frame_per_unit
         | 
| 375 | 
            +
                    stages = self.stages
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    # encode the image latents
         | 
| 378 | 
            +
                    image_transform = transforms.Compose([
         | 
| 379 | 
            +
                        transforms.ToTensor(),
         | 
| 380 | 
            +
                        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         | 
| 381 | 
            +
                    ])
         | 
| 382 | 
            +
                    input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2)   # [b c 1 h w]
         | 
| 383 | 
            +
                    input_image_latent = (self.vae.encode(input_image_tensor.to(device)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor  # [b c 1 h w]
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    generated_latents_list = [input_image_latent]    # The generated results
         | 
| 386 | 
            +
                    last_generated_latents = input_image_latent
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    for unit_index in tqdm(range(1, num_units + 1)):
         | 
| 389 | 
            +
                        if use_linear_guidance:
         | 
| 390 | 
            +
                            self._guidance_scale = guidance_scale_list[unit_index]
         | 
| 391 | 
            +
                            self._video_guidance_scale = guidance_scale_list[unit_index]
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                        # prepare the condition latents
         | 
| 394 | 
            +
                        past_condition_latents = []
         | 
| 395 | 
            +
                        clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
         | 
| 396 | 
            +
                        
         | 
| 397 | 
            +
                        for i_s in range(len(stages)):
         | 
| 398 | 
            +
                            last_cond_latent = clean_latents_list[i_s][:,:,-self.frame_per_unit:]
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                            stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
         | 
| 401 | 
            +
                    
         | 
| 402 | 
            +
                            # pad the past clean latents
         | 
| 403 | 
            +
                            cur_unit_num = unit_index
         | 
| 404 | 
            +
                cur_stage = i_s
                cur_unit_ptx = 1

                while cur_unit_ptx < cur_unit_num:
                    cur_stage = max(cur_stage - 1, 0)
                    if cur_stage == 0:
                        break
                    cur_unit_ptx += 1
                    cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
                    cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
                    stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                stage_input = list(reversed(stage_input))
                past_condition_latents.append(stage_input)

            intermed_latents = self.generate_one_unit(
                latents[:,:,(unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                num_inference_steps,
                height,
                width,
                self.frame_per_unit,
                device,
                dtype,
                generator,
                is_first_frame=False,
            )

            generated_latents_list.append(intermed_latents[-1])
            last_generated_latents = intermed_latents

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image
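The loop above assembles the conditioning list for one new unit: for each pyramid stage i_s it takes the most recent past unit at that stage's resolution, the unit before that one stage coarser, and all remaining older history at the base stage. A pure-Python mirror of that index walk (the helper name and printout are ours, for illustration only, not part of the upload):

# illustrative mirror of the stage/unit selection logic above
def condition_plan(num_stages, cur_unit_num):
    plan = []
    for i_s in range(num_stages):
        picks = [(i_s, 1)]                        # newest past unit, at this stage's resolution
        cur_stage, cur_unit_ptx = i_s, 1
        while cur_unit_ptx < cur_unit_num:
            cur_stage = max(cur_stage - 1, 0)
            if cur_stage == 0:
                break
            cur_unit_ptx += 1
            picks.append((cur_stage, cur_unit_ptx))   # older unit, one stage coarser
        if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
            picks.append((0, 'rest'))             # everything older, at the base stage
        plan.append(list(reversed(picks)))
    return plan

print(condition_plan(num_stages=3, cur_unit_num=4))
# [[(0, 'rest'), (0, 1)], [(0, 'rest'), (1, 1)], [(0, 'rest'), (1, 2), (2, 1)]]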

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        temp: int = 1,          # temporal length (number of latent frames)
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 7.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = "cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
    ):
        device = self.device
        dtype = self.dtype

        assert (temp - 1) % self.frame_per_unit == 0, "temp - 1 must be divisible by frame_per_unit"

        if isinstance(prompt, str):
            batch_size = 1
            prompt = prompt + ", hyper quality, Ultra HD, 8K"        # append quality tags to improve aesthetics
        else:
            assert isinstance(prompt, list)
            batch_size = len(prompt)
            prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        if isinstance(video_num_inference_steps, int):
            video_num_inference_steps = [video_num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Get the text embeddings
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            # guidance_scale_list = torch.linspace(max_guidance_scale, min_guidance_scale, temp).tolist()
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Create the initial random noise
        num_channels_latents = self.dit.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            temp,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

        latents = rearrange(latents, 'b c t h w -> (b t) c h w')
        # by default, we need to start from the block noise (the coarsest pyramid stage)
        for _ in range(len(self.stages) - 1):
            height //= 2; width //= 2
            latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2

        latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

        num_units = 1 + (temp - 1) // self.frame_per_unit
        stages = self.stages

        generated_latents_list = []    # The generated results
        last_generated_latents = None

        for unit_index in tqdm(range(num_units)):
            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            if unit_index == 0:
                past_condition_latents = [[] for _ in range(len(stages))]
                intermed_latents = self.generate_one_unit(
                    latents[:,:,:1],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    num_inference_steps,
                    height,
                    width,
                    1,
                    device,
                    dtype,
                    generator,
                    is_first_frame=True,
                )
            else:
                # prepare the condition latents
                past_condition_latents = []
                clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)

                for i_s in range(len(stages)):
                    last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]

                    stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]

                    # pad the past clean latents
                    cur_unit_num = unit_index
                    cur_stage = i_s
                    cur_unit_ptx = 1

                    while cur_unit_ptx < cur_unit_num:
                        cur_stage = max(cur_stage - 1, 0)
                        if cur_stage == 0:
                            break
                        cur_unit_ptx += 1
                        cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
                        stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                    if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
                        cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
                        stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

                    stage_input = list(reversed(stage_input))
                    past_condition_latents.append(stage_input)

                intermed_latents = self.generate_one_unit(
                    latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    video_num_inference_steps,
                    height,
                    width,
                    self.frame_per_unit,
                    device,
                    dtype,
                    generator,
                    is_first_frame=False,
                )

            generated_latents_list.append(intermed_latents[-1])
            last_generated_latents = intermed_latents

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image
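One detail worth calling out in generate(): the initial noise is created at full resolution and then pyramided down, with a `* 2` after every bilinear halving. Averaging roughly four i.i.d. noise pixels halves their standard deviation, so the rescale keeps the coarsest-stage noise at approximately unit variance. A standalone, runnable sketch of just that step (batch/channel/frame sizes and a three-stage pyramid are assumptions):

# sketch of the block-noise pyramiding performed inside generate() above
import torch
import torch.nn.functional as F
from einops import rearrange

latents = torch.randn(1, 16, 8, 64, 64)            # b c t h w (dummy sizes)
temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

latents = rearrange(latents, 'b c t h w -> (b t) c h w')
for _ in range(3 - 1):                             # len(self.stages) - 1, assuming 3 stages
    height //= 2; width //= 2
    latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

print(latents.shape)                               # torch.Size([1, 16, 8, 16, 16])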

    def decode_latent(self, latents, save_memory=True):
        if latents.shape[2] == 1:
            latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
        else:
            latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
            latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor

        if save_memory:
            # reducing the tile size and temporal chunk window size
            image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
        else:
            image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample

        image = image.float()
        image = (image / 2 + 0.5).clamp(0, 1)
        image = rearrange(image, "B C T H W -> (B T) C H W")
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = self.numpy_to_pil(image)
        return image

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @property
    def device(self):
        return next(self.dit.parameters()).device

    @property
    def dtype(self):
        return next(self.dit.parameters()).dtype

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def video_guidance_scale(self):
        return self._video_guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 0
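For orientation, the keyword surface of generate() gathered into one runnable snippet. Values are illustrative, a three-stage pyramid is assumed, and the final call is commented out because constructing the pipeline object itself is outside this file:

# the keys below follow the generate() signature above; the values are examples
gen_kwargs = dict(
    prompt="a fishing village at dawn, aerial view",
    height=768, width=1280,
    temp=16,                                 # (temp - 1) % frame_per_unit must be 0
    num_inference_steps=[20, 20, 20],        # per pyramid stage; a plain int is broadcast
    video_num_inference_steps=[10, 10, 10],
    guidance_scale=9.0,                      # per the parameter names: first-frame guidance
    video_guidance_scale=5.0,                # guidance for the following units
    output_type="pil",                       # "latent" skips VAE decoding
    save_memory=True,
)
# frames = pipeline.generate(**gen_kwargs)   # `pipeline`: an instance of this class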
    	
        requirements.txt
    ADDED
    
@@ -0,0 +1,32 @@
contexttimer
decord
diffusers>=0.30.1
accelerate==0.30.0
torch==2.1.2
torchvision==0.16.2
numpy==1.24.4
einops
ftfy
ipython
opencv-python-headless==4.10.0.84
imageio==2.33.1
imageio-ffmpeg==0.5.1
packaging
pandas
plotly
pre-commit
pycocoevalcap
pycocotools
python-magic
scikit-image
sentencepiece
spacy
streamlit
timm==0.6.12
tqdm
transformers==4.39.3
wheel
torchmetrics
tiktoken
jsonlines
tensorboardX
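A quick import check against the pins above (illustrative; run inside the target environment after `pip install -r requirements.txt`):

import torch, transformers, numpy, diffusers
print(torch.__version__)          # expect 2.1.2
print(transformers.__version__)   # expect 4.39.3
print(numpy.__version__)          # expect 1.24.4
print(diffusers.__version__)      # expect >= 0.30.1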
    	
        trainer_misc/__init__.py
    ADDED
    
@@ -0,0 +1,25 @@
from .utils import (
    create_optimizer,
    get_rank,
    get_world_size,
    is_main_process,
    is_dist_avail_and_initialized,
    init_distributed_mode,
    setup_for_distributed,
    cosine_scheduler,
    constant_scheduler,
)

from .sp_utils import (
    is_sequence_parallel_initialized,
    init_sequence_parallel_group,
    get_sequence_parallel_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sequence_parallel_group_rank,
    get_sequence_parallel_proc_num,
    init_sync_input_group,
    get_sync_input_group,
)

from .communicate import all_to_all
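These re-exports are the package's public surface. A small sketch of how the optimizer factory and a scheduler compose in a training script (argument values are made up; `args` carries only the fields create_optimizer reads for 'adamw'):

import argparse
import torch
from trainer_misc import create_optimizer, cosine_scheduler

args = argparse.Namespace(opt='adamw', lr=1e-4, weight_decay=0.05)
model = torch.nn.Linear(8, 8)

optimizer = create_optimizer(args, model)      # builds decay / no_decay param groups
lr_schedule = cosine_scheduler(base_value=args.lr, final_value=1e-6,
                               epochs=10, niter_per_ep=100, warmup_epochs=1)
assert len(lr_schedule) == 10 * 100            # one LR value per optimization step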
    	
        trainer_misc/communicate.py
    ADDED
    
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import math
import torch.distributed as dist


def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    if world_size == 1:
        return input_
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (
            grad_output,
            None,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    world_size: int = 1,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim)
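With the defaults scatter_dim=2 and gather_dim=1, one call re-shards a (batch, seq_shard, heads, dim) activation into (batch, full_seq, head_shard, dim) across the group, the usual exchange around attention in sequence parallelism; the backward pass simply runs the transposed exchange. A minimal shape check on a single process, where the exchange is defined as a no-op (real resharding needs torchrun and an initialized process group):

import torch
from trainer_misc.communicate import all_to_all

x = torch.randn(2, 16, 8, 64)                  # (B, S_local, H, D)
y = all_to_all(x, process_group=None, world_size=1)
assert y.shape == (2, 16, 8, 64)               # world_size == 1 returns the input as-is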
    	
        trainer_misc/sp_utils.py
    ADDED
    
@@ -0,0 +1,98 @@
import os
import torch
import torch.distributed as dist
from .utils import is_dist_avail_and_initialized, get_rank


SEQ_PARALLEL_GROUP = None
SEQ_PARALLEL_SIZE = None
SEQ_PARALLEL_PROC_NUM = None    # how many processes participate in sequence parallelism

SYNC_INPUT_GROUP = None
SYNC_INPUT_SIZE = None

def is_sequence_parallel_initialized():
    if SEQ_PARALLEL_GROUP is None:
        return False
    else:
        return True


def init_sequence_parallel_group(args):
    global SEQ_PARALLEL_GROUP
    global SEQ_PARALLEL_SIZE
    global SEQ_PARALLEL_PROC_NUM

    assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "PyTorch distributed must be initialized first"
    SEQ_PARALLEL_SIZE = args.sp_group_size

    print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    if args.sp_proc_num == -1:
        SEQ_PARALLEL_PROC_NUM = world_size
    else:
        SEQ_PARALLEL_PROC_NUM = args.sp_proc_num

    assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The number of sequence-parallel processes must be divisible by the group size"

    for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
        ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SEQ_PARALLEL_GROUP = group
            break


def init_sync_input_group(args):
    global SYNC_INPUT_GROUP
    global SYNC_INPUT_SIZE

    assert SYNC_INPUT_GROUP is None, "sync input group is already initialized"
    assert is_dist_avail_and_initialized(), "PyTorch distributed must be initialized first"
    SYNC_INPUT_SIZE = args.max_frames

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    for i in range(0, world_size, SYNC_INPUT_SIZE):
        ranks = list(range(i, i + SYNC_INPUT_SIZE))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SYNC_INPUT_GROUP = group
            break


def get_sequence_parallel_group():
    assert SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
    return SEQ_PARALLEL_GROUP


def get_sync_input_group():
    return SYNC_INPUT_GROUP


def get_sequence_parallel_world_size():
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    return SEQ_PARALLEL_SIZE


def get_sequence_parallel_rank():
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    rank = get_rank()
    cp_rank = rank % SEQ_PARALLEL_SIZE
    return cp_rank


def get_sequence_parallel_group_rank():
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    rank = get_rank()
    cp_group_rank = rank // SEQ_PARALLEL_SIZE
    return cp_group_rank


def get_sequence_parallel_proc_num():
    return SEQ_PARALLEL_PROC_NUM
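The group bookkeeping is consecutive-block arithmetic over global ranks: rank % sp_group_size is a process's position inside its sequence-parallel group, and rank // sp_group_size is the group index. A plain-Python illustration for 8 processes with sp_group_size = 4 (no torch.distributed needed to see the mapping):

sp_size = 4
for rank in range(8):
    print(f"rank {rank}: group {rank // sp_size}, position {rank % sp_size}")
# ranks 0-3 form group 0, ranks 4-7 form group 1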
    	
        trainer_misc/utils.py
    ADDED
    
    | @@ -0,0 +1,382 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import io
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import time
         | 
| 5 | 
            +
            import json
         | 
| 6 | 
            +
            import glob
         | 
| 7 | 
            +
            from collections import defaultdict, deque, OrderedDict
         | 
| 8 | 
            +
            import datetime
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            from pathlib import Path
         | 
| 13 | 
            +
            import argparse
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import torch
         | 
| 16 | 
            +
            from torch import optim as optim
         | 
| 17 | 
            +
            import torch.distributed as dist
         | 
| 18 | 
            +
            from tensorboardX import SummaryWriter
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def is_dist_avail_and_initialized():
         | 
| 22 | 
            +
                if not dist.is_available():
         | 
| 23 | 
            +
                    return False
         | 
| 24 | 
            +
                if not dist.is_initialized():
         | 
| 25 | 
            +
                    return False
         | 
| 26 | 
            +
                return True
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def get_world_size():
         | 
| 30 | 
            +
                if not is_dist_avail_and_initialized():
         | 
| 31 | 
            +
                    return 1
         | 
| 32 | 
            +
                return dist.get_world_size()
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def get_rank():
         | 
| 36 | 
            +
                if not is_dist_avail_and_initialized():
         | 
| 37 | 
            +
                    return 0
         | 
| 38 | 
            +
                return dist.get_rank()
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def is_main_process():
         | 
| 42 | 
            +
                return get_rank() == 0
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def save_on_master(*args, **kwargs):
         | 
| 46 | 
            +
                if is_main_process():
         | 
| 47 | 
            +
                    torch.save(*args, **kwargs)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def setup_for_distributed(is_master):
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                This function disables printing when not in master process
         | 
| 53 | 
            +
                """
         | 
| 54 | 
            +
                import builtins as __builtin__
         | 
| 55 | 
            +
                builtin_print = __builtin__.print
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def print(*args, **kwargs):
         | 
| 58 | 
            +
                    force = kwargs.pop('force', False)
         | 
| 59 | 
            +
                    if is_master or force:
         | 
| 60 | 
            +
                        builtin_print(*args, **kwargs)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                __builtin__.print = print
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def init_distributed_mode(args):
         | 
| 66 | 
            +
                if int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')) > 0:
         | 
| 67 | 
            +
                    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
         | 
| 68 | 
            +
                    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
         | 
| 69 | 
            +
                    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    os.environ["LOCAL_RANK"] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
         | 
| 72 | 
            +
                    os.environ["RANK"] = os.environ['OMPI_COMM_WORLD_RANK']
         | 
| 73 | 
            +
                    os.environ["WORLD_SIZE"] = os.environ['OMPI_COMM_WORLD_SIZE']
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    args.rank = int(os.environ["RANK"])
         | 
| 76 | 
            +
                    args.world_size = int(os.environ["WORLD_SIZE"])
         | 
| 77 | 
            +
                    args.gpu = int(os.environ["LOCAL_RANK"])
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
         | 
| 80 | 
            +
                    args.rank = int(os.environ["RANK"])
         | 
| 81 | 
            +
                    args.world_size = int(os.environ['WORLD_SIZE'])
         | 
| 82 | 
            +
                    args.gpu = int(os.environ['LOCAL_RANK'])
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                else:
         | 
| 85 | 
            +
                    print('Not using distributed mode')
         | 
| 86 | 
            +
                    args.distributed = False
         | 
| 87 | 
            +
                    return
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                args.distributed = True
         | 
| 90 | 
            +
                args.dist_backend = 'nccl'
         | 
| 91 | 
            +
                args.dist_url = "env://"
         | 
| 92 | 
            +
                print('| distributed init (rank {}): {}, gpu {}'.format(
         | 
| 93 | 
            +
                    args.rank, args.dist_url, args.gpu), flush=True)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 
         | 
| 97 | 
            +
                    start_warmup_value=0, warmup_steps=-1):
         | 
| 98 | 
            +
                warmup_schedule = np.array([])
         | 
| 99 | 
            +
                warmup_iters = warmup_epochs * niter_per_ep
         | 
| 100 | 
            +
                if warmup_steps > 0:
         | 
| 101 | 
            +
                    warmup_iters = warmup_steps
         | 
| 102 | 
            +
                print("Set warmup steps = %d" % warmup_iters)
         | 
| 103 | 
            +
                if warmup_epochs > 0:
         | 
| 104 | 
            +
                    warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                iters = np.arange(epochs * niter_per_ep - warmup_iters)
         | 
| 107 | 
            +
                schedule = np.array(
         | 
| 108 | 
            +
                    [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                schedule = np.concatenate((warmup_schedule, schedule))
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                assert len(schedule) == epochs * niter_per_ep
         | 
| 113 | 
            +
                return schedule
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            def constant_scheduler(base_value, epochs, niter_per_ep, warmup_epochs=0, 
         | 
| 117 | 
            +
                    start_warmup_value=1e-6, warmup_steps=-1):
         | 
| 118 | 
            +
                warmup_schedule = np.array([])
         | 
| 119 | 
            +
                warmup_iters = warmup_epochs * niter_per_ep
         | 
| 120 | 
            +
                if warmup_steps > 0:
         | 
| 121 | 
            +
                    warmup_iters = warmup_steps
         | 
| 122 | 
            +
                print("Set warmup steps = %d" % warmup_iters)
         | 
| 123 | 
            +
                if warmup_iters > 0:
         | 
| 124 | 
            +
                    warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                iters = epochs * niter_per_ep - warmup_iters
         | 
| 127 | 
            +
                schedule = np.array([base_value] * iters)
         | 
| 128 | 
            +
                
         | 
| 129 | 
            +
                schedule = np.concatenate((warmup_schedule, schedule))
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                assert len(schedule) == epochs * niter_per_ep
         | 
| 132 | 
            +
                return schedule
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            def get_parameter_groups(model, weight_decay=1e-5, base_lr=1e-4, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
         | 
| 136 | 
            +
                parameter_group_names = {}
         | 
| 137 | 
            +
                parameter_group_vars = {}
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                for name, param in model.named_parameters():
         | 
| 140 | 
            +
                    if not param.requires_grad:
         | 
| 141 | 
            +
                        continue  # frozen weights
         | 
| 142 | 
            +
                    if len(kwargs.get('filter_name', [])) > 0:
         | 
| 143 | 
            +
                        flag = False
         | 
| 144 | 
            +
                        for filter_n in kwargs.get('filter_name', []):
         | 
| 145 | 
            +
                            if filter_n in name:
         | 
| 146 | 
            +
                                print(f"filter {name} because of the pattern {filter_n}")
         | 
| 147 | 
            +
                                flag = True
         | 
| 148 | 
            +
                        if flag:
         | 
| 149 | 
            +
                            continue
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    default_scale=1.
         | 
| 152 | 
            +
                    
         | 
| 153 | 
            +
                    if param.ndim <= 1 or name.endswith(".bias") or name in skip_list: # param.ndim <= 1 len(param.shape) == 1
         | 
| 154 | 
            +
                        group_name = "no_decay"
         | 
| 155 | 
            +
                        this_weight_decay = 0.
         | 
| 156 | 
            +
                    else:
         | 
| 157 | 
            +
                        group_name = "decay"
         | 
| 158 | 
            +
                        this_weight_decay = weight_decay
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    if get_num_layer is not None:
         | 
| 161 | 
            +
                        layer_id = get_num_layer(name)
         | 
| 162 | 
            +
                        group_name = "layer_%d_%s" % (layer_id, group_name)
         | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        layer_id = None
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    if group_name not in parameter_group_names:
         | 
| 167 | 
            +
                        if get_layer_scale is not None:
         | 
| 168 | 
            +
                            scale = get_layer_scale(layer_id)
         | 
| 169 | 
            +
                        else:
         | 
| 170 | 
            +
                            scale = default_scale
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                        parameter_group_names[group_name] = {
         | 
| 173 | 
            +
                            "weight_decay": this_weight_decay,
         | 
| 174 | 
            +
                            "params": [],
         | 
| 175 | 
            +
                            "lr": base_lr,
         | 
| 176 | 
            +
                            "lr_scale": scale,
         | 
| 177 | 
            +
                        }
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        parameter_group_vars[group_name] = {
         | 
| 180 | 
            +
                            "weight_decay": this_weight_decay,
         | 
| 181 | 
            +
                            "params": [],
         | 
| 182 | 
            +
                            "lr": base_lr,
         | 
| 183 | 
            +
                            "lr_scale": scale,
         | 
| 184 | 
            +
                        }
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    parameter_group_vars[group_name]["params"].append(param)
         | 
| 187 | 
            +
                    parameter_group_names[group_name]["params"].append(name)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
         | 
| 190 | 
            +
                return list(parameter_group_vars.values())
         | 
| 191 | 
            +
             | 
| 192 | 
            +
             | 
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    skip = {}
    if skip_list is not None:
        skip = skip_list
    elif hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    print(f"Parameter names marked to skip weight decay: {skip}")
    parameters = get_parameter_groups(model, weight_decay, args.lr, skip, get_num_layer, get_layer_scale, **kwargs)
    weight_decay = 0.  # decay is already applied per parameter group above

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_beta1') and args.opt_beta1 is not None:
        opt_args['betas'] = (args.opt_beta1, args.opt_beta2)

    print('Optimizer config:', opt_args)
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    return optimizer


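# Illustrative only: a minimal sketch of how `create_optimizer` can be driven.
# The Namespace fields mirror the attributes read above (opt, lr, weight_decay,
# opt_eps, opt_beta1/2, momentum); the Linear model is a stand-in, and this
# helper is not called anywhere in the repo.
def _example_create_optimizer():
    from argparse import Namespace
    args = Namespace(opt='adamw', lr=1e-4, weight_decay=0.05,
                     opt_eps=1e-8, opt_beta1=0.9, opt_beta2=0.95, momentum=0.9)
    model = torch.nn.Linear(16, 16)
    return create_optimizer(args, model)

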
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
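

# Illustrative only: the intended logging pattern for the two classes above.
# The range(...) iterable and the fake loss stand in for a dataloader and real
# training metrics; this helper is not called anywhere in the repo.
def _example_metric_logger():
    logger = MetricLogger(delimiter="  ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    for step in logger.log_every(range(100), print_freq=10, header='Epoch: [0]'):
        logger.update(loss=1.0 / (step + 1), lr=1e-4)
    logger.synchronize_between_processes()  # no-op without torch.distributed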
utils.py
ADDED
@@ -0,0 +1,457 @@
import os
import torch
import PIL.Image
import numpy as np
from torch import nn
import torch.distributed as dist
import timm.models.hub as timm_hub

"""Modified from https://github.com/CompVis/taming-transformers.git"""

import hashlib
import requests
from tqdm import tqdm
try:
    import piq
except ImportError:
    # piq is optional; it is only needed for perceptual quality metrics
    pass

_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def is_context_parallel_initialized():
    if _CONTEXT_PARALLEL_GROUP is None:
        return False
    else:
        return True


def set_context_parallel_group(size, group):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE
    _CONTEXT_PARALLEL_GROUP = group
    _CONTEXT_PARALLEL_SIZE = size


def initialize_context_parallel(context_parallel_size):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE

    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
    _CONTEXT_PARALLEL_SIZE = context_parallel_size

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    for i in range(0, world_size, context_parallel_size):
        ranks = range(i, i + context_parallel_size)
        # `new_group` is a collective call: every rank must create every
        # subgroup in the same order, so keep iterating instead of breaking
        # out once the local rank's group has been found.
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group


def get_context_parallel_group():
    assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"

    return _CONTEXT_PARALLEL_GROUP


def get_context_parallel_world_size():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    return _CONTEXT_PARALLEL_SIZE


def get_context_parallel_rank():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    rank = get_rank()
    cp_rank = rank % _CONTEXT_PARALLEL_SIZE
    return cp_rank


def get_context_parallel_group_rank():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    rank = get_rank()
    cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE

    return cp_group_rank


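# Illustrative only: with world_size=8 and context_parallel_size=2 the loop in
# `initialize_context_parallel` builds the groups {0,1}, {2,3}, {4,5}, {6,7};
# rank 5 then gets cp_rank 5 % 2 = 1 and cp_group_rank 5 // 2 = 2. This sketch
# reproduces that rank layout without touching torch.distributed.
def _example_context_parallel_layout(world_size=8, context_parallel_size=2):
    return {
        rank: {"cp_rank": rank % context_parallel_size,
               "cp_group": rank // context_parallel_size}
        for rank in range(world_size)
    }

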
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


             | 
| 131 | 
            +
            def convert_weights_to_fp16(model: nn.Module):
         | 
| 132 | 
            +
                """Convert applicable model parameters to fp16"""
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def _convert_weights_to_fp16(l):
         | 
| 135 | 
            +
                    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
         | 
| 136 | 
            +
                        l.weight.data = l.weight.data.to(torch.float16)
         | 
| 137 | 
            +
                        if l.bias is not None:
         | 
| 138 | 
            +
                            l.bias.data = l.bias.data.to(torch.float16)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                model.apply(_convert_weights_to_fp16)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def convert_weights_to_bf16(model: nn.Module):
         | 
| 144 | 
            +
                """Convert applicable model parameters to fp16"""
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def _convert_weights_to_bf16(l):
         | 
| 147 | 
            +
                    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
         | 
| 148 | 
            +
                        l.weight.data = l.weight.data.to(torch.bfloat16)
         | 
| 149 | 
            +
                        if l.bias is not None:
         | 
| 150 | 
            +
                            l.bias.data = l.bias.data.to(torch.bfloat16)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                model.apply(_convert_weights_to_bf16)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
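# Illustrative only: the converters above cast conv/linear weights in place, so
# they are applied once after a model is built; the Linear layer here is a
# placeholder and this helper is not called anywhere in the repo.
def _example_convert_weights():
    layer = nn.Linear(8, 8)
    convert_weights_to_bf16(layer)
    assert layer.weight.dtype == torch.bfloat16

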
def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
    import json
    import jsonlines
    print("Dump result")

    # Make the temp dir for saving results
    if not os.path.exists(result_dir):
        if is_main_process():
            os.makedirs(result_dir)
        if is_dist_avail_and_initialized():
            torch.distributed.barrier()

    result_file = os.path.join(
        result_dir, "%s_rank%d.json" % (filename, get_rank())
    )

    final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")

    json.dump(result, open(result_file, "w"))

    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    if is_main_process():
        # combine results from all processes
        result = []

        for rank in range(get_world_size()):
            result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
            res = json.load(open(result_file, "r"))
            result += res

        # remove duplicate records keyed by `remove_duplicate`
        if remove_duplicate:
            result_new = []
            id_set = set()
            for res in result:
                if res[remove_duplicate] not in id_set:
                    id_set.add(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        if save_format == 'json':
            json.dump(result, open(final_result_file, "w"))
        else:
            assert save_format == 'jsonl', "Only json and jsonl formats are supported"
            with jsonlines.open(final_result_file, "w") as writer:
                writer.write_all(result)

    return final_result_file


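# Illustrative only: every rank dumps its shard to "<filename>_rank<r>.json",
# then rank 0 merges the shards, de-duplicates on the given key, and writes the
# final file. The records and the /tmp path are made up for this sketch.
def _example_save_result():
    records = [{"id": 0, "caption": "a"}, {"id": 0, "caption": "a"}]
    return save_result(records, "/tmp/results", "eval", remove_duplicate="id")

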
# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now the kernel size. Good results are obtained with 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # but it does so in two passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]

    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)

    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
    return output


def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding


def _filter2d(input, kernel):
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    out = output.view(b, c, h, w)
    return out


def _gaussian(window_size: int, sigma):
    if isinstance(sigma, float):
        sigma = torch.tensor([[sigma]])

    batch_size = sigma.shape[0]

    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)


def _gaussian_blur2d(input, kernel_size, sigma):
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])

    return out


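# Illustrative only: the resize helper blurs with a sigma derived from the
# downscale factor and then interpolates; the random tensor stands in for a
# (B, C, H, W) batch of frames. Not called anywhere in the repo.
def _example_resize_with_antialiasing():
    frames = torch.randn(1, 3, 512, 512)
    resized = _resize_with_antialiasing(frames, (256, 256))
    return resized.shape  # torch.Size([1, 3, 256, 256])

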
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


             | 
| 332 | 
            +
            def download(url, local_path, chunk_size=1024):
         | 
| 333 | 
            +
                os.makedirs(os.path.split(local_path)[0], exist_ok=True)
         | 
| 334 | 
            +
                with requests.get(url, stream=True) as r:
         | 
| 335 | 
            +
                    total_size = int(r.headers.get("content-length", 0))
         | 
| 336 | 
            +
                    with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
         | 
| 337 | 
            +
                        with open(local_path, "wb") as f:
         | 
| 338 | 
            +
                            for data in r.iter_content(chunk_size=chunk_size):
         | 
| 339 | 
            +
                                if data:
         | 
| 340 | 
            +
                                    f.write(data)
         | 
| 341 | 
            +
                                    pbar.update(chunk_size)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
             | 
| 344 | 
            +
            def md5_hash(path):
         | 
| 345 | 
            +
                with open(path, "rb") as f:
         | 
| 346 | 
            +
                    content = f.read()
         | 
| 347 | 
            +
                return hashlib.md5(content).hexdigest()
         | 
| 348 | 
            +
             | 
| 349 | 
            +
             | 
| 350 | 
            +
            def get_ckpt_path(name, root, check=False):
         | 
| 351 | 
            +
                assert name in URL_MAP
         | 
| 352 | 
            +
                path = os.path.join(root, CKPT_MAP[name])
         | 
| 353 | 
            +
                print(md5_hash(path))
         | 
| 354 | 
            +
                if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
         | 
| 355 | 
            +
                    print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
         | 
| 356 | 
            +
                    download(URL_MAP[name], path)
         | 
| 357 | 
            +
                    md5 = md5_hash(path)
         | 
| 358 | 
            +
                    assert md5 == MD5_MAP[name], md5
         | 
| 359 | 
            +
                return path
         | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
            class KeyNotFoundError(Exception):
         | 
| 363 | 
            +
                def __init__(self, cause, keys=None, visited=None):
         | 
| 364 | 
            +
                    self.cause = cause
         | 
| 365 | 
            +
                    self.keys = keys
         | 
| 366 | 
            +
                    self.visited = visited
         | 
| 367 | 
            +
                    messages = list()
         | 
| 368 | 
            +
                    if keys is not None:
         | 
| 369 | 
            +
                        messages.append("Key not found: {}".format(keys))
         | 
| 370 | 
            +
                    if visited is not None:
         | 
| 371 | 
            +
                        messages.append("Visited: {}".format(visited))
         | 
| 372 | 
            +
                    messages.append("Cause:\n{}".format(cause))
         | 
| 373 | 
            +
                    message = "\n".join(messages)
         | 
| 374 | 
            +
                    super().__init__(message)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
             | 
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``.

    Raises
    ------
        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
        ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
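

# Illustrative only: `retrieve` walks a nested structure along a "/"-delimited
# key and expands callable nodes in place; this config dict is made up and the
# helper is not called anywhere in the repo.
def _example_retrieve():
    config = {"model": {"params": [{"lr": 1e-4}], "build": lambda: {"depth": 24}}}
    lr = retrieve(config, "model/params/0/lr")        # -> 0.0001
    depth = retrieve(config, "model/build/depth")     # expands the callable node
    missing = retrieve(config, "model/missing", default=0)
    return lr, depth, missing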
video_generation_demo.ipynb
ADDED
@@ -0,0 +1,181 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import PIL\n",
    "from PIL import Image\n",
    "from IPython.display import HTML\n",
    "from pyramid_dit import PyramidDiTForVideoGeneration\n",
    "from IPython.display import Image as ipython_image\n",
    "from diffusers.utils import load_image, export_to_video, export_to_gif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant='diffusion_transformer_768p'         # For high resolution\n",
    "# variant='diffusion_transformer_384p'       # For low resolution\n",
    "\n",
    "model_path = \"/home/jinyang06/models/pyramid-flow\"   # The downloaded checkpoint dir\n",
    "model_dtype = 'bf16'\n",
    "\n",
    "device_id = 0\n",
    "torch.cuda.set_device(device_id)\n",
    "\n",
    "model = PyramidDiTForVideoGeneration(\n",
    "    model_path,\n",
    "    model_dtype,\n",
    "    model_variant=variant,\n",
    ")\n",
    "\n",
    "model.vae.to(\"cuda\")\n",
    "model.dit.to(\"cuda\")\n",
    "model.text_encoder.to(\"cuda\")\n",
    "\n",
    "if model_dtype == \"bf16\":\n",
    "    torch_dtype = torch.bfloat16\n",
    "elif model_dtype == \"fp16\":\n",
    "    torch_dtype = torch.float16\n",
    "else:\n",
    "    torch_dtype = torch.float32\n",
    "\n",
    "\n",
    "def show_video(ori_path, rec_path, width=\"100%\"):\n",
    "    html = ''\n",
    "    if ori_path is not None:\n",
    "        html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "        <source src=\"{ori_path}\" type=\"video/mp4\">\n",
    "        </video>\n",
    "        \"\"\"\n",
    "    \n",
    "    html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "    <source src=\"{rec_path}\" type=\"video/mp4\">\n",
    "    </video>\n",
    "    \"\"\"\n",
    "    return HTML(html)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Text-to-Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
    "\n",
    "# used for 384p model variant\n",
    "# width = 640\n",
    "# height = 384\n",
    "\n",
    "# used for 768p model variant\n",
    "width = 1280\n",
    "height = 768\n",
    "\n",
    "temp = 16   # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
    "\n",
    "model.vae.enable_tiling()\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
    "    frames = model.generate(\n",
    "        prompt=prompt,\n",
    "        num_inference_steps=[20, 20, 20],\n",
    "        video_num_inference_steps=[10, 10, 10],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        temp=temp,\n",
    "        guidance_scale=9.0,         # The guidance for the first frame\n",
    "        video_guidance_scale=5.0,   # The guidance for the remaining video latents\n",
    "        output_type=\"pil\",\n",
    "        save_memory=True,           # If you have enough GPU memory, set it to `False` to speed up VAE decoding\n",
    "    )\n",
    "\n",
    "export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
    "show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
         | 
| 118 | 
            +
               "metadata": {},
         | 
| 119 | 
            +
               "source": [
         | 
| 120 | 
            +
                "#### Image-to-Video"
         | 
| 121 | 
            +
               ]
         | 
| 122 | 
            +
              },
         | 
| 123 | 
            +
              {
         | 
| 124 | 
            +
               "cell_type": "code",
         | 
| 125 | 
            +
               "execution_count": null,
         | 
| 126 | 
            +
               "metadata": {},
         | 
| 127 | 
            +
               "outputs": [],
         | 
| 128 | 
            +
               "source": [
         | 
| 129 | 
            +
                "image_path = 'assets/the_great_wall.jpg'\n",
         | 
| 130 | 
            +
                "image = Image.open(image_path).convert(\"RGB\")\n",
         | 
| 131 | 
            +
                "\n",
         | 
| 132 | 
            +
                "width = 1280\n",
         | 
| 133 | 
            +
                "height = 768\n",
         | 
| 134 | 
            +
                "temp = 16\n",
         | 
| 135 | 
            +
                "\n",
         | 
| 136 | 
            +
                "image = image.resize((width, height))\n",
         | 
| 137 | 
            +
                "\n",
         | 
| 138 | 
            +
                "display(image)\n",
         | 
| 139 | 
            +
                "\n",
         | 
| 140 | 
            +
                "prompt = \"FPV flying over the Great Wall\"\n",
         | 
| 141 | 
            +
                "\n",
         | 
| 142 | 
            +
                "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
         | 
| 143 | 
            +
                "    frames = model.generate_i2v(\n",
         | 
| 144 | 
            +
                "        prompt=prompt,\n",
         | 
| 145 | 
            +
                "        input_image=image,\n",
         | 
| 146 | 
            +
                "        num_inference_steps=[10, 10, 10],\n",
         | 
| 147 | 
            +
                "        temp=temp,\n",
         | 
| 148 | 
            +
                "        guidance_scale=7.0,\n",
         | 
| 149 | 
            +
                "        video_guidance_scale=4.0,\n",
         | 
| 150 | 
            +
                "        output_type=\"pil\",\n",
         | 
| 151 | 
            +
                "        save_memory=True,         # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
         | 
| 152 | 
            +
                "    )\n",
         | 
| 153 | 
            +
                "\n",
         | 
| 154 | 
            +
                "export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
         | 
| 155 | 
            +
                "show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
         | 
| 156 | 
            +
               ]
         | 
| 157 | 
            +
              }
         | 
| 158 | 
            +
             ],
         | 
| 159 | 
            +
             "metadata": {
         | 
| 160 | 
            +
              "kernelspec": {
         | 
| 161 | 
            +
               "display_name": "Python 3",
         | 
| 162 | 
            +
               "language": "python",
         | 
| 163 | 
            +
               "name": "python3"
         | 
| 164 | 
            +
              },
         | 
| 165 | 
            +
              "language_info": {
         | 
| 166 | 
            +
               "codemirror_mode": {
         | 
| 167 | 
            +
                "name": "ipython",
         | 
| 168 | 
            +
                "version": 3
         | 
| 169 | 
            +
               },
         | 
| 170 | 
            +
               "file_extension": ".py",
         | 
| 171 | 
            +
               "mimetype": "text/x-python",
         | 
| 172 | 
            +
               "name": "python",
         | 
| 173 | 
            +
               "nbconvert_exporter": "python",
         | 
| 174 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 175 | 
            +
               "version": "3.8.10"
         | 
| 176 | 
            +
              },
         | 
| 177 | 
            +
              "orig_nbformat": 4
         | 
| 178 | 
            +
             },
         | 
| 179 | 
            +
             "nbformat": 4,
         | 
| 180 | 
            +
             "nbformat_minor": 2
         | 
| 181 | 
            +
            }
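The `temp` argument in both generation cells sets the number of latent frames. The endpoints in the cell comment (`temp` 1–31 <=> 1–241 frames) are consistent with an 8x temporally-compressing causal VAE, where the first latent frame decodes to one image frame and each later latent frame to eight, so the default `temp = 16` decodes to 121 frames, about 5 s at 24 fps. A quick sanity check of that mapping (the helper below is illustrative, not part of the repo):

def temp_to_frames(temp: int) -> int:
    # first latent frame -> 1 image frame, each subsequent one -> 8 frames
    return 8 * (temp - 1) + 1

for temp in (1, 16, 31):
    n = temp_to_frames(temp)
    print(f"temp={temp:2d} -> {n:3d} frames -> {n / 24:.2f}s at fps=24")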
    	
    video_vae/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .modeling_loss import LPIPSWithDiscriminator
from .modeling_causal_vae import CausalVideoVAE
    	
    video_vae/context_parallel_ops.py
ADDED
@@ -0,0 +1,172 @@
# from cogvideoX
import torch
import torch.nn as nn
import math

from utils import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
    get_context_parallel_group_rank,
)


def _conv_split(input_, dim=2, kernel_size=1):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)

    cp_rank = get_context_parallel_rank()

    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
        # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
        output = input_.transpose(dim, 0)[
            cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
        ].transpose(dim, 0)
    output = output.contiguous()

    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output

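# A single-process sketch of the indexing above (demo helper only; toy sizes,
# no process group): along `dim`, rank 0 keeps `dim_size + kernel_size` slices
# and every other rank keeps `dim_size`, so the kernel_size-slice causal head
# lives only on rank 0 and the chunks tile the axis with no overlap or gap.
def _demo_conv_split_indexing(cp_world_size=4, dim_size=5, kernel_size=2):
    x = torch.arange(cp_world_size * dim_size + kernel_size)
    chunks = [x[: dim_size + kernel_size]]  # rank 0, causal head included
    for r in range(1, cp_world_size):
        chunks.append(x[r * dim_size + kernel_size : (r + 1) * dim_size + kernel_size])
    assert torch.equal(torch.cat(chunks), x)
    return [chunk.tolist() for chunk in chunks]
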
def _conv_gather(input_, dim=2, kernel_size=1):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
    else:
        input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()

    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output

def _cp_pass_from_previous_rank(input_, dim, kernel_size):
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_group_rank = get_context_parallel_group_rank()
    cp_world_size = get_context_parallel_world_size()

    # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)

    global_rank = torch.distributed.get_rank()
    global_world_size = torch.distributed.get_world_size()

    input_ = input_.transpose(0, dim)

    # pass from last rank
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
    if cp_rank < cp_world_size - 1:
        req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
    if cp_rank > 0:
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    input_ = input_.transpose(0, dim).contiguous()
    return input_

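# A single-process sketch of the halo exchange above (demo helper only): each
# rank prepends the last kernel_size - 1 temporal slices of its left
# neighbour, while rank 0 zero-pads instead, so a causal temporal conv of
# width kernel_size sees the same left context it would on the unsplit
# sequence.
def _demo_halo_exchange(kernel_size=3, cp_world_size=3, shard_len=4):
    shards = [torch.arange(shard_len) + shard_len * r for r in range(cp_world_size)]
    padded = []
    for rank, shard in enumerate(shards):
        if rank == 0:
            halo = torch.zeros(kernel_size - 1, dtype=shard.dtype)
        else:
            halo = shards[rank - 1][-(kernel_size - 1):]  # what irecv would deliver
        padded.append(torch.cat([halo, shard]))
    # e.g. [[0, 0, 0, 1, 2, 3], [2, 3, 4, 5, 6, 7], [6, 7, 8, 9, 10, 11]]
    return [p.tolist() for p in padded]
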
def _drop_from_previous_rank(input_, dim, kernel_size):
    input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
    return input_


class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_split(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_gather(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None


class _CPConvolutionPassFromPreviousRank(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _cp_pass_from_previous_rank(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None


def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)


def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)


def cp_pass_from_previous_rank(input_, dim, kernel_size):
    return _CPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
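
# Sketch of how the three public wrappers compose on the encode side
# (illustrative only; the real callers live in the causal conv / VAE modules,
# and a context-parallel group must already be initialized via the utils
# helpers). Here dim=2 is the temporal axis of a (B, C, T, H, W) tensor and
# `kt` is the temporal kernel size of `conv3d`.
def _sharded_causal_conv_sketch(x, conv3d, kt):
    x = conv_scatter_to_context_parallel_region(x, dim=2, kernel_size=kt)  # shard T
    x = cp_pass_from_previous_rank(x, dim=2, kernel_size=kt)               # fetch halo
    x = conv3d(x)                                                          # local causal conv
    return conv_gather_from_context_parallel_region(x, dim=2, kernel_size=kt)  # rebuild T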
    	
    video_vae/modeling_block.py
ADDED
@@ -0,0 +1,760 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from diffusers.utils import logging
from diffusers.models.attention_processor import Attention
from .modeling_resnet import (
    Downsample2D, ResnetBlock2D, CausalResnetBlock3D, Upsample2D,
    TemporalDownsample2x, TemporalUpsample2x,
    CausalDownsample2x, CausalTemporalDownsample2x,
    CausalUpsample2x, CausalTemporalUpsample2x,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_input_layer(
    in_channels: int,
    out_channels: int,
    norm_num_groups: int,
    layer_type: str,
    norm_type: str = 'group',
    affine: bool = True,
):
    if layer_type == 'conv':
        input_layer = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
    elif layer_type == 'pixel_shuffle':
        input_layer = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(in_channels * 4, out_channels, kernel_size=1),
        )
    else:
        raise NotImplementedError(f"Unsupported input layer type: {layer_type}")

    return input_layer


def get_output_layer(
    in_channels: int,
    out_channels: int,
    norm_num_groups: int,
    layer_type: str,
    norm_type: str = 'group',
    affine: bool = True,
):
    if layer_type == 'norm_act_conv':
        output_layer = nn.Sequential(
            nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6, affine=affine),
            nn.SiLU(),
            nn.Conv3d(in_channels, out_channels, 3, stride=1, padding=1),
        )
    elif layer_type == 'pixel_shuffle':
        output_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )
    else:
        raise NotImplementedError(f"Unsupported output layer type: {layer_type}")

    return output_layer

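# Quick shape check for the pixel_shuffle pair above (demo helper only, dummy
# sizes): PixelUnshuffle(2) folds each 2x2 spatial patch into channels and
# PixelShuffle(2) inverts it, so the two layers trade resolution for channels
# losslessly around the 1x1 convolutions.
def _demo_pixel_shuffle_layers():
    x = torch.randn(1, 3, 8, 8)
    down = get_input_layer(3, 16, norm_num_groups=32, layer_type='pixel_shuffle')
    up = get_output_layer(16, 3, norm_num_groups=32, layer_type='pixel_shuffle')
    y = down(x)         # (1, 3, 8, 8) -> (1, 16, 4, 4)
    return up(y).shape  # back to (1, 3, 8, 8)
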

def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int = None,
    temb_channels: int = None,
    add_spatial_downsample: bool = None,
    add_temporal_downsample: bool = None,
    resnet_eps: float = 1e-6,
    resnet_act_fn: str = 'silu',
    resnet_groups: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    resnet_time_scale_shift: str = "default",
    attention_head_dim: Optional[int] = None,
    dropout: float = 0.0,
    norm_affline: bool = True,
    norm_layer: str = 'layer',
):
    if down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_spatial_downsample=add_spatial_downsample,
            add_temporal_downsample=add_temporal_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    elif down_block_type == "DownEncoderBlockCausal3D":
        return DownEncoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_spatial_downsample=add_spatial_downsample,
            add_temporal_downsample=add_temporal_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int = None,
    temb_channels: int = None,
    add_spatial_upsample: bool = None,
    add_temporal_upsample: bool = None,
    resnet_eps: float = 1e-6,
    resnet_act_fn: str = 'silu',
    resolution_idx: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    resnet_time_scale_shift: str = "default",
    attention_head_dim: Optional[int] = None,
    dropout: float = 0.0,
    interpolate: bool = True,
    norm_affline: bool = True,
    norm_layer: str = 'layer',
) -> nn.Module:
    if up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_spatial_upsample=add_spatial_upsample,
            add_temporal_upsample=add_temporal_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
            interpolate=interpolate,
        )

    elif up_block_type == "UpDecoderBlockCausal3D":
        return UpDecoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_spatial_upsample=add_spatial_upsample,
            add_temporal_upsample=add_temporal_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
            interpolate=interpolate,
        )

    raise ValueError(f"{up_block_type} does not exist.")

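# For orientation, one causal encoder/decoder block pair built through the
# factories above (demo helper only; channel sizes and flags are illustrative,
# not taken from the repo's configs):
def _demo_block_factories():
    down = get_down_block(
        "DownEncoderBlockCausal3D",
        num_layers=2,
        in_channels=128,
        out_channels=256,
        add_spatial_downsample=True,     # halve H and W at the end of the block
        add_temporal_downsample=False,   # keep the temporal length
        resnet_groups=32,
    )
    up = get_up_block(
        "UpDecoderBlockCausal3D",
        num_layers=2,
        in_channels=256,
        out_channels=128,
        add_spatial_upsample=True,
        add_temporal_upsample=False,
        resnet_groups=32,
    )
    return down, up
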

class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            The type of normalization to apply to the time embeddings. This can help to improve the performance of the
            model on tasks with long-range temporal dependencies.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                # Spatial attention
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        t = hidden_states.shape[2]

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                # Fold the temporal axis into the batch so 2D spatial attention runs per frame
                hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
                hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
                hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')

            hidden_states = resnet(hidden_states, temb)

        return hidden_states

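# Shape check for the fold/unfold used around the attention call above (demo
# helper only): folding time into batch lets plain 2D spatial attention run
# independently per frame, and the round trip is lossless.
def _demo_per_frame_attention_folding(b=2, c=8, t=5, h=16, w=16):
    x = torch.randn(b, c, t, h, w)
    folded = rearrange(x, 'b c t h w -> (b t) c h w')             # (b*t, c, h, w)
    restored = rearrange(folded, '(b t) c h w -> b c t h w', t=t)
    assert torch.equal(x, restored)
    return folded.shape
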
class CausalUNetMidBlock2D(nn.Module):
    """
    A causal UNet mid-block [`CausalUNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            The type of normalization to apply to the time embeddings. This can help to improve the performance of the
            model on tasks with long-range temporal dependencies.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, num_frames, height, width)`.

    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        resnets = [
            CausalResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                # Spatial attention
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                CausalResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
            is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        t = hidden_states.shape[2]

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
                hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
                hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')

            hidden_states = resnet(hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return hidden_states

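# --- Editor's note (illustrative sketch, not part of the original file) ---
# A quick shape check for CausalUNetMidBlock2D as defined above, assuming its
# defaults behave like their diffusers counterparts (temb is optional). The
# head count follows heads = in_channels // attention_head_dim, e.g. 128 // 64 = 2.
def _demo_causal_mid_block_shapes():
    block = CausalUNetMidBlock2D(in_channels=128, temb_channels=None, attention_head_dim=64)
    x = torch.randn(1, 128, 4, 16, 16)   # (b, c, t, h, w)
    y = block(x, temb=None)
    assert y.shape == x.shape            # the mid-block preserves the shape
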
class DownEncoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_spatial_downsample: bool = True,
        add_temporal_downsample: bool = False,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                CausalResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_spatial_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CausalDownsample2x(
                        out_channels, use_conv=True, out_channels=out_channels,
                    )
                ]
            )
        else:
            self.downsamplers = None

        if add_temporal_downsample:
            self.temporal_downsamplers = nn.ModuleList(
                [
                    CausalTemporalDownsample2x(
                        out_channels, use_conv=True, out_channels=out_channels,
                    )
                ]
            )
        else:
            self.temporal_downsamplers = None

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.temporal_downsamplers is not None:
            for temporal_downsampler in self.temporal_downsamplers:
                hidden_states = temporal_downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return hidden_states

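# --- Editor's note (illustrative sketch, not part of the original file) ---
# With both flags enabled, DownEncoderBlockCausal3D widens the channels and
# halves H and W via CausalDownsample2x; CausalTemporalDownsample2x halves the
# temporal length (its exact output length depends on the causal padding).
def _demo_down_encoder_causal_shapes():
    block = DownEncoderBlockCausal3D(
        in_channels=64, out_channels=128,
        add_spatial_downsample=True, add_temporal_downsample=True,
    )
    x = torch.randn(1, 64, 8, 32, 32)    # (b, c, t, h, w)
    y = block(x, is_init_image=True, temporal_chunk=False)
    assert y.shape[1] == 128 and y.shape[-2:] == (16, 16)
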
class DownEncoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_spatial_downsample: bool = True,
        add_temporal_downsample: bool = False,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_spatial_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        if add_temporal_downsample:
            self.temporal_downsamplers = nn.ModuleList(
                [
                    TemporalDownsample2x(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding,
                    )
                ]
            )
        else:
            self.temporal_downsamplers = None

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        if self.temporal_downsamplers is not None:
            for temporal_downsampler in self.temporal_downsamplers:
                hidden_states = temporal_downsampler(hidden_states)

        return hidden_states

class UpDecoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_spatial_upsample: bool = True,
        add_temporal_upsample: bool = False,
        temb_channels: Optional[int] = None,
        interpolate: bool = True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_spatial_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
        else:
            self.upsamplers = None

        if add_temporal_upsample:
            self.temporal_upsamplers = nn.ModuleList([TemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
        else:
            self.temporal_upsamplers = None

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, is_image: bool = False,
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb, scale=scale)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        if self.temporal_upsamplers is not None:
            for temporal_upsampler in self.temporal_upsamplers:
                hidden_states = temporal_upsampler(hidden_states, is_image=is_image)

        return hidden_states

class UpDecoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_spatial_upsample: bool = True,
        add_temporal_upsample: bool = False,
        temb_channels: Optional[int] = None,
        interpolate: bool = True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                CausalResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_spatial_upsample:
            self.upsamplers = nn.ModuleList([CausalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
        else:
            self.upsamplers = None

        if add_temporal_upsample:
            self.temporal_upsamplers = nn.ModuleList([CausalTemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
        else:
            self.temporal_upsamplers = None

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
        is_init_image=True, temporal_chunk=False,
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.temporal_upsamplers is not None:
            for temporal_upsampler in self.temporal_upsamplers:
                hidden_states = temporal_upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return hidden_states

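# --- Editor's note (illustrative sketch, not part of the original file) ---
# The decoder-side mirror of the encoder check above: UpDecoderBlockCausal3D
# doubles H and W when add_spatial_upsample=True (and would double T when
# add_temporal_upsample=True), assuming CausalUpsample2x upsamples by 2x.
def _demo_up_decoder_causal_shapes():
    block = UpDecoderBlockCausal3D(
        in_channels=128, out_channels=64,
        add_spatial_upsample=True, add_temporal_upsample=False,
    )
    x = torch.randn(1, 128, 4, 16, 16)    # (b, c, t, h, w)
    y = block(x, temb=None, is_init_image=True, temporal_chunk=False)
    assert y.shape == (1, 64, 4, 32, 32)  # channels shrink, H/W double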
    	
        video_vae/modeling_causal_conv.py
    ADDED
    
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch import Tensor

from utils import (
    is_context_parallel_initialized,
    get_context_parallel_group,
    get_context_parallel_world_size,
    get_context_parallel_rank,
    get_context_parallel_group_rank,
)

from .context_parallel_ops import (
    conv_scatter_to_context_parallel_region,
    conv_gather_from_context_parallel_region,
    cp_pass_from_previous_rank,
)


def divisible_by(num, den):
    return (num % den) == 0

def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)

def is_odd(n):
    return not divisible_by(n, 2)

class CausalGroupNorm(nn.GroupNorm):

    def forward(self, x: Tensor) -> Tensor:
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = super().forward(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
        return x

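# --- Editor's note (illustrative sketch, not part of the original file) ---
# CausalGroupNorm is plain nn.GroupNorm applied frame-by-frame: folding t into
# the batch leaves per-(sample, frame) statistics untouched, so it matches
# normalizing each frame independently with torch's functional group_norm.
def _demo_causal_group_norm():
    norm = CausalGroupNorm(num_groups=4, num_channels=16)
    x = torch.randn(2, 16, 3, 8, 8)      # (b, c, t, h, w)
    ref = torch.stack(
        [F.group_norm(x[:, :, i], 4, norm.weight, norm.bias, norm.eps) for i in range(3)],
        dim=2,
    )
    assert torch.allclose(norm(x), ref, atol=1e-6)
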
class CausalConv3d(nn.Module):

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size: Union[int, Tuple[int, int, int]],
            stride: Union[int, Tuple[int, int, int]] = 1,
            pad_mode: str = 'constant',
            **kwargs
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        self.time_kernel_size = time_kernel_size
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
        dilation = kwargs.pop('dilation', 1)
        self.pad_mode = pad_mode

        if isinstance(stride, int):
            stride = (stride, 1, 1)

        time_pad = dilation * (time_kernel_size - 1)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.temporal_stride = stride[0]
        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
        self.cache_front_feat = deque()

    def _clear_context_parallel_cache(self):
        del self.cache_front_feat
        self.cache_front_feat = deque()

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def context_parallel_forward(self, x):
        x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)

        x = F.pad(x, self.time_uncausal_padding, mode='constant')

        cp_rank = get_context_parallel_rank()
        if cp_rank != 0:
            if self.temporal_stride == 2 and self.time_kernel_size == 3:
                x = x[:, :, 1:]

        x = self.conv(x)
        return x

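    # --- Editor's note (added comment): with kernel_size=(3, 3, 3), stride 1
    # and dilation 1, time_pad = 2 and time_causal_padding = (1, 1, 1, 1, 2, 0).
    # F.pad consumes pairs from the last axis backwards, so this pads W by
    # (1, 1), H by (1, 1) and T by (2, 0): two past frames on the left, none
    # on the right. Each output frame therefore sees only current and earlier
    # input frames, and the temporal length is preserved for stride 1.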
    def forward(self, x, is_init_image=True, temporal_chunk=False):
        # temporal_chunk: whether to process the video in temporal chunks,
        # caching boundary frames between successive calls (inference only)

        if is_context_parallel_initialized():
            return self.context_parallel_forward(x)

        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        if not temporal_chunk:
            x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        else:
            assert not self.training, "The feature cache should not be used in training"
            if is_init_image:
                # Encode the first chunk
                x = F.pad(x, self.time_causal_padding, mode=pad_mode)
                self._clear_context_parallel_cache()
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
            else:
                x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
                video_front_context = self.cache_front_feat.pop()
                self._clear_context_parallel_cache()

                if self.temporal_stride == 1 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context, x], dim=2)
                elif self.temporal_stride == 2 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context[:, :, -1:], x], dim=2)

                self.cache_front_feat.append(x[:, :, -2:].clone().detach())

        x = self.conv(x)
        return x
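# --- Editor's note (illustrative sketch, not part of the original file) ---
# Typical streaming use of the temporal cache, assuming no context parallelism
# is initialized: the first chunk is passed with is_init_image=True (causal
# left-padding; cache primed with the last two padded frames), later chunks
# with is_init_image=False reuse the cache, reproducing the non-chunked result.
def _demo_temporal_chunk_streaming():
    conv = CausalConv3d(8, 8, kernel_size=3).eval()   # stride 1, k_t = 3
    video = torch.randn(1, 8, 12, 16, 16)             # (b, c, t, h, w)
    with torch.no_grad():
        first = conv(video[:, :, :4], is_init_image=True, temporal_chunk=True)
        rest = conv(video[:, :, 4:], is_init_image=False, temporal_chunk=True)
        full = conv(video, is_init_image=True, temporal_chunk=False)
    assert torch.allclose(torch.cat([first, rest], dim=2), full, atol=1e-5)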
    	
        video_vae/modeling_causal_vae.py
    ADDED
    
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)

from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin

from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_enc_dec import (
    DecoderOutput, DiagonalGaussianDistribution,
    CausalVaeDecoder, CausalVaeEncoder,
)
from .modeling_causal_conv import CausalConv3d

from utils import (
    is_context_parallel_initialized,
    get_context_parallel_group,
    get_context_parallel_world_size,
    get_context_parallel_rank,
    get_context_parallel_group_rank,
)

from .context_parallel_ops import (
    conv_scatter_to_context_parallel_region,
    conv_gather_from_context_parallel_region,
)

| 40 | 
            +
class CausalVideoVAE(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations back into videos.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, defaults to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # encoder related parameters
        encoder_in_channels: int = 3,
        encoder_out_channels: int = 4,
        encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 2),
        encoder_down_block_types: Tuple[str, ...] = (
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
        ),
        encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        encoder_spatial_down_sample: Tuple[bool, ...] = (True, True, True, False),
        encoder_temporal_down_sample: Tuple[bool, ...] = (True, True, True, False),
        encoder_block_dropout: Tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
        encoder_act_fn: str = "silu",
        encoder_norm_num_groups: int = 32,
        encoder_double_z: bool = True,
        encoder_type: str = 'causal_vae_conv',
        # decoder related parameters
        decoder_in_channels: int = 4,
        decoder_out_channels: int = 3,
        decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3),
        decoder_up_block_types: Tuple[str, ...] = (
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
        ),
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        decoder_spatial_up_sample: Tuple[bool, ...] = (True, True, True, False),
        decoder_temporal_up_sample: Tuple[bool, ...] = (True, True, True, False),
        decoder_block_dropout: Tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
        decoder_act_fn: str = "silu",
        decoder_norm_num_groups: int = 32,
        decoder_type: str = 'causal_vae_conv',
        sample_size: int = 256,
        scaling_factor: float = 0.18215,
        add_post_quant_conv: bool = True,
        interpolate: bool = False,
        downsample_scale: int = 8,
    ):
        super().__init__()

                    print(f"The latent dimmension channes is {encoder_out_channels}")
         | 
| 121 | 
            +
                    # pass init params to Encoder
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    self.encoder = CausalVaeEncoder(
         | 
| 124 | 
            +
                        in_channels=encoder_in_channels,
         | 
| 125 | 
            +
                        out_channels=encoder_out_channels,
         | 
| 126 | 
            +
                        down_block_types=encoder_down_block_types,
         | 
| 127 | 
            +
                        spatial_down_sample=encoder_spatial_down_sample,
         | 
| 128 | 
            +
                        temporal_down_sample=encoder_temporal_down_sample,
         | 
| 129 | 
            +
                        block_out_channels=encoder_block_out_channels,
         | 
| 130 | 
            +
                        layers_per_block=encoder_layers_per_block,
         | 
| 131 | 
            +
                        act_fn=encoder_act_fn,
         | 
| 132 | 
            +
                        norm_num_groups=encoder_norm_num_groups,
         | 
| 133 | 
            +
                        double_z=True,
         | 
| 134 | 
            +
                        block_dropout=encoder_block_dropout,
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # pass init params to Decoder
         | 
| 138 | 
            +
                    self.decoder = CausalVaeDecoder(
         | 
| 139 | 
            +
                        in_channels=decoder_in_channels,
         | 
| 140 | 
            +
                        out_channels=decoder_out_channels,
         | 
| 141 | 
            +
                        up_block_types=decoder_up_block_types,
         | 
| 142 | 
            +
                        spatial_up_sample=decoder_spatial_up_sample,
         | 
| 143 | 
            +
                        temporal_up_sample=decoder_temporal_up_sample,
         | 
| 144 | 
            +
                        block_out_channels=decoder_block_out_channels,
         | 
| 145 | 
            +
                        layers_per_block=decoder_layers_per_block,
         | 
| 146 | 
            +
                        norm_num_groups=decoder_norm_num_groups,
         | 
| 147 | 
            +
                        act_fn=decoder_act_fn,
         | 
| 148 | 
            +
                        interpolate=interpolate,
         | 
| 149 | 
            +
                        block_dropout=decoder_block_dropout,
         | 
| 150 | 
            +
                    )
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    self.quant_conv = CausalConv3d(2 * encoder_out_channels, 2 * encoder_out_channels, kernel_size=1, stride=1)
         | 
| 153 | 
            +
                    self.post_quant_conv = CausalConv3d(encoder_out_channels, encoder_out_channels, kernel_size=1, stride=1)
         | 
| 154 | 
            +
                    self.use_tiling = False
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    # only relevant if vae tiling is enabled
         | 
| 157 | 
            +
                    self.tile_sample_min_size = self.config.sample_size
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    sample_size = (
         | 
| 160 | 
            +
                        self.config.sample_size[0]
         | 
| 161 | 
            +
                        if isinstance(self.config.sample_size, (list, tuple))
         | 
| 162 | 
            +
                        else self.config.sample_size
         | 
| 163 | 
            +
                    )
         | 
| 164 | 
            +
                    self.tile_latent_min_size = int(sample_size / downsample_scale) 
         | 
| 165 | 
            +
                    self.encode_tile_overlap_factor = 1 / 8
         | 
| 166 | 
            +
                    self.decode_tile_overlap_factor = 1 / 8
         | 
| 167 | 
            +
                    self.downsample_scale = downsample_scale
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    self.apply(self._init_weights)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
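    # Shape sketch (illustrative, not part of the original file): with the
    # defaults above (three 2x spatial and temporal downsamples, downsample_scale=8,
    # encoder_out_channels=4, double_z=True), a clip of shape (B, 3, 1 + 8k, H, W)
    # is encoded to moments of shape (B, 8, 1 + k, H / 8, W / 8), which
    # DiagonalGaussianDistribution then splits into a 4-channel mean and a
    # 4-channel log-variance.
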
    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CausalVaeEncoder, CausalVaeDecoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True,
        is_init_image=True, temporal_chunk=False, window_size=16, tile_sample_min_size=256,
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        self.tile_sample_min_size = tile_sample_min_size
        self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)

        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict, is_init_image=is_init_image,
                temporal_chunk=temporal_chunk, window_size=window_size)

        if temporal_chunk:
            moments = self.chunk_encode(x, window_size=window_size)
        else:
            h = self.encoder(x, is_init_image=is_init_image, temporal_chunk=False)
            moments = self.quant_conv(h, is_init_image=is_init_image, temporal_chunk=False)

        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    @torch.no_grad()
    def chunk_encode(self, x: torch.FloatTensor, window_size=16):
        # Only used during inference: encode a long video clip with a sliding
        # window so that only `window_size` frames are processed at a time.
        num_frames = x.shape[2]
        assert (num_frames - 1) % self.downsample_scale == 0
        init_window_size = window_size + 1
        frame_list = [x[:, :, :init_window_size]]

        # Chunk the long video into fixed-size windows
        full_chunk_size = (num_frames - init_window_size) // window_size
        fid = init_window_size
        for idx in range(full_chunk_size):
            frame_list.append(x[:, :, fid:fid + window_size])
            fid += window_size

        if fid < num_frames:
            frame_list.append(x[:, :, fid:])

        latent_list = []
        for idx, frames in enumerate(frame_list):
            if idx == 0:
                h = self.encoder(frames, is_init_image=True, temporal_chunk=True)
                moments = self.quant_conv(h, is_init_image=True, temporal_chunk=True)
            else:
                h = self.encoder(frames, is_init_image=False, temporal_chunk=True)
                moments = self.quant_conv(h, is_init_image=False, temporal_chunk=True)

            latent_list.append(moments)

        latent = torch.cat(latent_list, dim=2)
        return latent

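    # Worked example for chunk_encode (illustrative, not part of the original
    # file): with window_size=16 and downsample_scale=8, a clip of
    # num_frames=33 satisfies (33 - 1) % 8 == 0. The first chunk covers frames
    # 0:17 (init_window_size = 17), one full chunk covers frames 17:33, and no
    # tail remains. Assuming the causal encoder maps 1 + 8k frames to 1 + k
    # latent frames, the chunks contribute 3 and 2 latent frames, so the
    # concatenation along dim=2 matches the 5 latent frames of a single pass.
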
    def get_last_layer(self):
        return self.decoder.conv_out.conv.weight

    @torch.no_grad()
    def chunk_decode(self, z: torch.FloatTensor, window_size=2):
        # Decode a long latent sequence with a sliding window, mirroring chunk_encode
        num_frames = z.shape[2]
        init_window_size = window_size + 1
        frame_list = [z[:, :, :init_window_size]]

        # Chunk the long latent video into fixed-size windows
        full_chunk_size = (num_frames - init_window_size) // window_size
        fid = init_window_size
        for idx in range(full_chunk_size):
            frame_list.append(z[:, :, fid:fid + window_size])
            fid += window_size

        if fid < num_frames:
            frame_list.append(z[:, :, fid:])

        dec_list = []
        for idx, frames in enumerate(frame_list):
            if idx == 0:
                z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
                dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
            else:
                z_h = self.post_quant_conv(frames, is_init_image=False, temporal_chunk=True)
                dec = self.decoder(z_h, is_init_image=False, temporal_chunk=True)

            dec_list.append(dec)

        dec = torch.cat(dec_list, dim=2)
        return dec

    def decode(
        self, z: torch.FloatTensor, is_init_image=True, temporal_chunk=False,
        return_dict: bool = True, window_size: int = 2, tile_sample_min_size: int = 256,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        self.tile_sample_min_size = tile_sample_min_size
        self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)

        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, is_init_image=is_init_image,
                    temporal_chunk=temporal_chunk, window_size=window_size, return_dict=return_dict)

        if temporal_chunk:
            dec = self.chunk_decode(z, window_size=window_size)
        else:
            z = self.post_quant_conv(z, is_init_image=is_init_image, temporal_chunk=False)
            dec = self.decoder(z, is_init_image=is_init_image, temporal_chunk=False)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Linearly cross-fade the vertically overlapping rows of tiles a and b
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Linearly cross-fade the horizontally overlapping columns of tiles a and b
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

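    # Worked numbers for the blend helpers (illustrative, not part of the
    # original file): with blend_extent = 4, row y of tile b becomes
    #   a[..., -4 + y, :] * (1 - y / 4) + b[..., y, :] * (y / 4)
    # for y in 0..3, so the weight on b ramps 0, 1/4, 1/2, 3/4 across the seam.
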
    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True,
            is_init_image=True, temporal_chunk=False, window_size=16,) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
        the output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.encode_tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.encode_tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles of side tile_sample_min_size and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                if temporal_chunk:
                    tile = self.chunk_encode(tile, window_size=window_size)
                else:
                    tile = self.encoder(tile, is_init_image=True, temporal_chunk=False)
                    tile = self.quant_conv(tile, is_init_image=True, temporal_chunk=False)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)

        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

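    # Worked numbers for tiled_encode (illustrative, assuming the defaults
    # tile_sample_min_size=256 and downsample_scale=8): overlap_size =
    # int(256 * (1 - 1/8)) = 224, tile_latent_min_size = 32, blend_extent =
    # int(32 * 1/8) = 4, and row_limit = 28. Tiles therefore start every 224
    # pixels and adjacent tiles overlap by 32 pixels, and each blended latent
    # tile contributes a 28x28 spatial crop before concatenation along dims 4 and 3.
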
    def tiled_decode(self, z: torch.FloatTensor, is_init_image=True,
            temporal_chunk=False, window_size=2, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.decode_tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.decode_tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of side tile_latent_min_size and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                if temporal_chunk:
                    decoded = self.chunk_decode(tile, window_size=window_size)
                else:
                    tile = self.post_quant_conv(tile, is_init_image=True, temporal_chunk=False)
                    decoded = self.decoder(tile, is_init_image=True, temporal_chunk=False)
                row.append(decoded)
            rows.append(row)
        result_rows = []

        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = True,
        generator: Optional[torch.Generator] = None,
        freeze_encoder: bool = False,
        is_init_image=True,
        temporal_chunk=False,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `True`):
                Whether to sample from the posterior.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
            freeze_encoder (`bool`, *optional*, defaults to `False`):
                Whether to run the encoder under `torch.no_grad()`.
        """
        x = sample

        if is_context_parallel_initialized():
            assert self.training, "Context parallelism is only supported during training for now"

            if freeze_encoder:
                with torch.no_grad():
                    h = self.encoder(x, is_init_image=True, temporal_chunk=False)
                    moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
                    posterior = DiagonalGaussianDistribution(moments)
                    global_posterior = posterior
            else:
                h = self.encoder(x, is_init_image=True, temporal_chunk=False)
                moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
                posterior = DiagonalGaussianDistribution(moments)
                # Gather the moments from all context-parallel ranks so the
                # global posterior covers the full temporal sequence
                global_moments = conv_gather_from_context_parallel_region(moments, dim=2, kernel_size=1)
                global_posterior = DiagonalGaussianDistribution(global_moments)

            if sample_posterior:
                z = posterior.sample(generator=generator)
            else:
                z = posterior.mode()

            if get_context_parallel_rank() == 0:
                dec = self.decode(z, is_init_image=True).sample
            else:
                # Do not drop the first upsampled frame
                dec = self.decode(z, is_init_image=False).sample

            return global_posterior, dec

        else:
            # The normal training path
            if freeze_encoder:
                with torch.no_grad():
                    posterior = self.encode(x, is_init_image=is_init_image,
                            temporal_chunk=temporal_chunk).latent_dist
            else:
                posterior = self.encode(x, is_init_image=is_init_image,
                        temporal_chunk=temporal_chunk).latent_dist

            if sample_posterior:
                z = posterior.sample(generator=generator)
            else:
                z = posterior.mode()

            dec = self.decode(z, is_init_image=is_init_image, temporal_chunk=temporal_chunk).sample

            return posterior, dec

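    # Minimal training-step sketch (illustrative only; the repo's actual losses
    # live in modeling_loss.py, `F` is torch.nn.functional, and this assumes the
    # DiagonalGaussianDistribution here mirrors the diffusers API's kl() method):
    #   posterior, dec = vae(video, sample_posterior=True)
    #   loss = F.mse_loss(dec, video) + kl_weight * posterior.kl().mean()
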
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
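
# Minimal inference sketch (illustrative only; `vae_path`, the dtype, and the
# window sizes are assumptions, not values fixed by this file):
#   vae = CausalVideoVAE.from_pretrained(vae_path, torch_dtype=torch.bfloat16)
#   vae.enable_tiling()  # tile spatially once frames exceed tile_sample_min_size
#   latent = vae.encode(video, temporal_chunk=True, window_size=16).latent_dist.sample()
#   recon = vae.decode(latent, temporal_chunk=True, window_size=2).sample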
video_vae/modeling_discriminator.py
ADDED
import functools
import torch.nn as nn
from einops import rearrange
import torch


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=4):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
        """
        super(NLayerDiscriminator, self).__init__()

        # norm_layer = nn.BatchNorm2d
        norm_layer = nn.InstanceNorm2d

        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output a 1-channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)

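Because the PatchGAN head is fully convolutional, it returns a grid of real/fake logits rather than a single scalar; each logit scores one overlapping patch of the input. A minimal smoke test (illustrative only, not part of the file; the shapes follow from the defaults above):

# --- illustrative usage, not part of the file ---
import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4)
images = torch.randn(1, 3, 256, 256)   # (batch, channels, H, W)
logits = disc(images)                  # per-patch real/fake logits
assert logits.shape == (1, 1, 14, 14)  # each logit judges one input patch
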
class NLayerDiscriminator3D(nn.Module):
    """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """
        Construct a 3D PatchGAN discriminator

        Parameters:
            input_nc (int)     -- the number of channels in input volumes
            ndf (int)          -- the number of filters in the last conv layer
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- flag to use actnorm instead of batchnorm
        """
        super(NLayerDiscriminator3D, self).__init__()
        # if not use_actnorm:
        #     norm_layer = nn.BatchNorm3d
        # else:
        #     raise NotImplementedError("Not implemented.")
        norm_layer = nn.InstanceNorm3d

        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm3d
        else:
            use_bias = norm_layer != nn.BatchNorm3d

        kw = 4
        padw = 1
        sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1, 2, 2), padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output a 1-channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
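Only the first convolution strides in time; the inner blocks use stride (1, 2, 2), so the temporal axis is halved once and afterwards only trimmed by the unpadded kernel, while the spatial axes are halved at every stage. A smoke test with illustrative shapes (not part of the file):

# --- illustrative usage, not part of the file ---
import torch

disc3d = NLayerDiscriminator3D(input_nc=3, ndf=64, n_layers=3)
video = torch.randn(1, 3, 16, 64, 64)  # (batch, channels, frames, H, W)
logits = disc3d(video)                 # per spatio-temporal patch logits
assert logits.shape == (1, 1, 4, 6, 6)
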
    	
video_vae/modeling_enc_dec.py
ADDED
@@ -0,0 +1,422 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .modeling_block import (
    UNetMidBlock2D,
    CausalUNetMidBlock2D,
    get_down_block,
    get_up_block,
    get_input_layer,
    get_output_layer,
)
from .modeling_resnet import (
    Downsample2D,
    Upsample2D,
    TemporalDownsample2x,
    TemporalUpsample2x,
)
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm

@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor

class CausalVaeEncoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlockCausal3D",)`):
            The types of down blocks to use. See `get_down_block` in `modeling_block` for the available options.
        spatial_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Whether each down block halves the spatial resolution.
        temporal_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Whether each down block halves the temporal resolution.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block (mean and log-variance of the latent).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        spatial_down_sample: Tuple[bool, ...] = (True,),
        temporal_down_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        block_dropout: Tuple[float, ...] = (0.0,),
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_spatial_downsample=spatial_down_sample[i],
                add_temporal_downsample=temporal_down_sample[i],
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                dropout=block_dropout[i],
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, is_init_image,
                        temporal_chunk, use_reentrant=False,
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, is_init_image,
                    temporal_chunk, use_reentrant=False,
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, is_init_image, temporal_chunk)
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample

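All per-stage tuples must have the same length: entry `i` of `down_block_types`, `spatial_down_sample`, `temporal_down_sample`, `block_out_channels`, `layers_per_block`, and `block_dropout` together configure stage `i`. A hypothetical four-stage configuration (the channel counts and compression factors here are illustrative, not the shipped checkpoint's config):

# --- illustrative configuration, not part of the file ---
encoder = CausalVaeEncoder(
    in_channels=3,
    out_channels=8,                                   # doubled to 16 by double_z=True
    down_block_types=("DownEncoderBlockCausal3D",) * 4,
    spatial_down_sample=(True, True, True, False),    # 2x * 2x * 2x = 8x spatially
    temporal_down_sample=(False, True, True, False),  # 2x * 2x = 4x temporally
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=(2, 2, 2, 2),
    block_dropout=(0.0, 0.0, 0.0, 0.0),
)
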
class CausalVaeDecoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlockCausal3D",)`):
            The types of up blocks to use. See `get_up_block` in `modeling_block` for the available options.
        spatial_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Whether each up block doubles the spatial resolution.
        temporal_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Whether each up block doubles the temporal resolution.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        spatial_up_sample: Tuple[bool, ...] = (True,),
        temporal_up_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        interpolate: bool = True,
        block_dropout: Tuple[float, ...] = (0.0,),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_spatial_upsample=spatial_up_sample[i],
                add_temporal_upsample=temporal_up_sample[i],
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift='default',
                interpolate=interpolate,
                dropout=block_dropout[i],
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                # Forward keyword arguments through the checkpoint boundary:
                # the calls below pass is_init_image/temporal_chunk as keywords.
                def custom_forward(*inputs, **kwargs):
                    return module(*inputs, **kwargs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    is_init_image=is_init_image,
                    temporal_chunk=temporal_chunk,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        is_init_image=is_init_image,
                        temporal_chunk=temporal_chunk,
                        use_reentrant=False,
                    )
            else:
                # reentrant checkpointing only supports positional arguments
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, is_init_image, temporal_chunk,
                    )
        else:
            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample

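The decoder is configured as the encoder's mirror image: its tuples are indexed over up blocks while `block_out_channels` is reversed internally. A hypothetical counterpart to the encoder sketch above (values illustrative only):

# --- illustrative configuration, not part of the file ---
decoder = CausalVaeDecoder(
    in_channels=8,                                  # the encoder's latent channels
    out_channels=3,
    up_block_types=("UpDecoderBlockCausal3D",) * 4,
    spatial_up_sample=(False, True, True, True),    # undo the 8x spatial compression
    temporal_up_sample=(False, True, True, False),  # undo the 4x temporal compression
    block_out_channels=(128, 256, 512, 512),        # reversed internally
    layers_per_block=(2, 2, 2, 2),
    block_dropout=(0.0, 0.0, 0.0, 0.0),
)
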
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[2, 3, 4],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[2, 3, 4],
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
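The encoder's `double_z` output feeds this class directly: the 2C channels are split along dim 1 into mean and log-variance, and `sample()` applies the reparameterization trick x = mean + std * eps, so gradients flow through the random draw. A short sketch with illustrative shapes (not part of the file):

# --- illustrative usage, not part of the file ---
import torch

moments = torch.randn(1, 16, 4, 32, 32)  # (B, 2C, T, H, W) from the encoder
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()   # stochastic draw, shape (1, 8, 4, 32, 32)
z_det = posterior.mode() # deterministic alternative: the mean
kl = posterior.kl()      # KL to N(0, I), reduced over dims (2, 3, 4) -> shape (1, 8)
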
    	
video_vae/modeling_loss.py
ADDED
@@ -0,0 +1,192 @@
import os
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from .modeling_lpips import LPIPS
from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init

class AdaptiveLossWeight:
    def __init__(self, timestep_range=[0, 1], buckets=300, weight_range=[1e-7, 1e7]):
        self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, timestep):
        indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep)
        return (1 / self.bucket_losses.to(timestep.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, timestep, loss, beta=0.99):
        indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)

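Each bucket holds an EMA of the loss observed in its timestep range, and `weight()` returns the reciprocal of that EMA (clamped to `weight_range`), so timesteps that persistently incur large losses are down-weighted toward a uniform contribution. A hypothetical training-step fragment (names are illustrative):

# --- illustrative usage, not part of the file ---
import torch

adaptive = AdaptiveLossWeight(timestep_range=[0, 1], buckets=300)

timesteps = torch.rand(4)                            # one timestep per sample in [0, 1]
per_sample_loss = torch.rand(4)                      # stand-in for the real per-sample loss
weighted_loss = (adaptive.weight(timesteps) * per_sample_loss).mean()
adaptive.update_buckets(timesteps, per_sample_loss)  # refresh the EMA statistics
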
def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))
        + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight

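`adopt_weight` acts as the warm-up gate for the adversarial term: it returns `value` (0.0) until `global_step` reaches `threshold`, so the autoencoder first trains on the reconstruction, perceptual, and KL losses alone. For example (the threshold here is illustrative):

# --- illustrative usage, not part of the file ---
disc_factor = adopt_weight(1.0, global_step=500, threshold=5001)   # 0.0: GAN term off
disc_factor = adopt_weight(1.0, global_step=6000, threshold=5001)  # 1.0: GAN term on
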
class LPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        perceptual_weight=1.0,
        # --- Discriminator Loss ---
        disc_num_layers=4,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=0.5,
        disc_loss="hinge",
        add_discriminator=True,
        using_3d_discriminator=False,
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        if add_discriminator:
            disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator
            self.discriminator = disc_cls(
                input_nc=disc_in_channels, n_layers=disc_num_layers,
            ).apply(weights_init)
        else:
            self.discriminator = None

        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.using_3d_discriminator = using_3d_discriminator

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]

        # balance the gradient magnitudes of the reconstruction and adversarial objectives
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

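`calculate_adaptive_weight` follows the VQGAN recipe: the adversarial term is rescaled so that its gradient at the decoder's last layer has roughly the same magnitude as the reconstruction gradient, i.e. d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4) * disc_weight.

# --- worked example (illustrative numbers, not part of the file) ---
# ||grad(nll_loss)|| = 1.0, ||grad(g_loss)|| = 4.0, disc_weight = 0.5
# d_weight = 1.0 / (4.0 + 1e-4) * 0.5 ≈ 0.125
# so neither objective dominates the shared decoder.
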
    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        split="train",
        last_layer=None,
    ):
        t = reconstructions.shape[2]
        inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
        reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()

        if optimizer_idx == 0:
            # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True)
            rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1, 2, 3), keepdim=True)

            # defaults so nll_loss and p_loss stay defined when perceptual_weight == 0
            nll_loss = self.pixel_weight * rec_loss
            p_loss = torch.zeros_like(rec_loss)
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(inputs, reconstructions)
                nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss

            nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar
            weighted_nll_loss = nll_loss
            weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
            nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

            kl_loss = posteriors.kl()
            kl_loss = torch.mean(kl_loss)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )

            if disc_factor > 0.0:
                if self.using_3d_discriminator:
                    reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)

                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0)

            loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )
            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/perception_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            if self.using_3d_discriminator:
                            inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t)
         | 
| 177 | 
            +
                            reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        logits_real = self.discriminator(inputs.contiguous().detach())
         | 
| 180 | 
            +
                        logits_fake = self.discriminator(reconstructions.contiguous().detach())
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                        disc_factor = adopt_weight(
         | 
| 183 | 
            +
                            self.disc_factor, global_step, threshold=self.discriminator_iter_start
         | 
| 184 | 
            +
                        )
         | 
| 185 | 
            +
                        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                        log = {
         | 
| 188 | 
            +
                            "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
         | 
| 189 | 
            +
                            "{}/logits_real".format(split): logits_real.detach().mean(),
         | 
| 190 | 
            +
                            "{}/logits_fake".format(split): logits_fake.detach().mean(),
         | 
| 191 | 
            +
                        }
         | 
| 192 | 
            +
                        return d_loss, log
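
For reference, the generator objective assembled in `forward` (optimizer_idx == 0) is

$$ \mathcal{L} \;=\; \mathcal{L}_{\mathrm{nll}} \;+\; \lambda_{\mathrm{KL}}\,\mathcal{L}_{\mathrm{KL}} \;+\; d_{\mathrm{weight}} \cdot \delta_{\mathrm{step}} \cdot \mathcal{L}_{G}, $$

where $\delta_{\mathrm{step}}$ is the `adopt_weight` gate (zero until `global_step` reaches `discriminator_iter_start`) and $d_{\mathrm{weight}}$ is the adaptive balance computed in `calculate_adaptive_weight` from gradient norms at the last decoder layer $w_L$:

$$ d_{\mathrm{weight}} \;=\; \lambda_{\mathrm{disc}} \cdot \mathrm{clamp}\!\left( \frac{\lVert \nabla_{w_L} \mathcal{L}_{\mathrm{nll}} \rVert}{\lVert \nabla_{w_L} \mathcal{L}_{G} \rVert + 10^{-4}},\; 0,\; 10^{4} \right). $$

This keeps the adversarial term from overpowering the reconstruction term regardless of its raw scale.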

video_vae/modeling_lpips.py
ADDED
@@ -0,0 +1,120 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=False, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        ckpt = "/home/jinyang/models/vae/video_vae_baseline/vgg_lpips.pth"    # replace with your lpips
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=True)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # shift/scale map inputs in [-1, 1] to the statistics VGG was trained with
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    # unit-normalize each spatial feature vector across the channel dimension
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)


if __name__ == "__main__":
    model = LPIPS().eval()
    _ = torch.manual_seed(123)
    img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
    img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
    print(model(img1, img2).shape)
    # embed()
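
In equation form, `LPIPS.forward` above computes, for channel-unit-normalized VGG16 activations $\hat f^{\,l}$ at the five ReLU taps and learned $1{\times}1$ weights $w_l$:

$$ d(x, y) \;=\; \sum_{l=1}^{5} \frac{1}{H_l W_l} \sum_{h,w} \sum_{c} w_{l,c} \left( \hat f^{\,l}_{c,hw}(x) - \hat f^{\,l}_{c,hw}(y) \right)^{2}, $$

returned with shape `(B, 1, 1, 1)`, which is what the `__main__` check prints.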
    	
video_vae/modeling_resnet.py
ADDED
@@ -0,0 +1,729 @@
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.normalization import AdaGroupNorm
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm


class CausalResnetBlock3D(nn.Module):
    r"""
    A causal 3D ResNet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in the timestep embedding.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose
            "ada_group" or "spatial" for a stronger conditioning with scale and shift.
        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, defaults to `None`):
            If `True`, add a 1x1 causal conv layer for the skip-connection; if `None`, it is enabled
            automatically whenever the input and output channel counts differ.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = True  # pre-normalization is always used in this block
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)

        self.nonlinearity = get_activation(non_linearity)
        self.upsample = self.downsample = None
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = CausalConv3d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor = None,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

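A minimal shape check for the block above — a sketch only, assuming this repo's `video_vae` package is importable and that `CausalConv3d` pads causally so that `kernel_size=3, stride=1` convolutions preserve the `(t, h, w)` extents:

import torch
from video_vae.modeling_resnet import CausalResnetBlock3D

block = CausalResnetBlock3D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(2, 64, 9, 32, 32)   # (b, c, t, h, w)
y = block(x, temb=None)             # default norm path: CausalGroupNorm
print(y.shape)                      # expected: torch.Size([2, 128, 9, 32, 32])

Because `in_channels != conv_2d_out_channels` here, the 1x1 `conv_shortcut` is created automatically so the residual sum is well-defined.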
            +
            class ResnetBlock2D(nn.Module):
         | 
| 154 | 
            +
                r"""
         | 
| 155 | 
            +
                A Resnet block.
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                Parameters:
         | 
| 158 | 
            +
                    in_channels (`int`): The number of channels in the input.
         | 
| 159 | 
            +
                    out_channels (`int`, *optional*, default to be `None`):
         | 
| 160 | 
            +
                        The number of output channels for the first conv2d layer. If None, same as `in_channels`.
         | 
| 161 | 
            +
                    dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
         | 
| 162 | 
            +
                    temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
         | 
| 163 | 
            +
                    groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
         | 
| 164 | 
            +
                    groups_out (`int`, *optional*, default to None):
         | 
| 165 | 
            +
                        The number of groups to use for the second normalization layer. if set to None, same as `groups`.
         | 
| 166 | 
            +
                    eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         | 
| 167 | 
            +
                    non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         | 
| 168 | 
            +
                    time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
         | 
| 169 | 
            +
                        By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
         | 
| 170 | 
            +
                        "ada_group" for a stronger conditioning with scale and shift.
         | 
| 171 | 
            +
                    kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
         | 
| 172 | 
            +
                        [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         | 
| 173 | 
            +
                    output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         | 
| 174 | 
            +
                    use_in_shortcut (`bool`, *optional*, default to `True`):
         | 
| 175 | 
            +
                        If `True`, add a 1x1 nn.conv2d layer for skip-connection.
         | 
| 176 | 
            +
                    up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
         | 
| 177 | 
            +
                    down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
         | 
| 178 | 
            +
                    conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
         | 
| 179 | 
            +
                        `conv_shortcut` output.
         | 
| 180 | 
            +
                    conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
         | 
| 181 | 
            +
                        If None, same as `out_channels`.
         | 
| 182 | 
            +
                """
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                def __init__(
         | 
| 185 | 
            +
                    self,
         | 
| 186 | 
            +
                    *,
         | 
| 187 | 
            +
                    in_channels: int,
         | 
| 188 | 
            +
                    out_channels: Optional[int] = None,
         | 
| 189 | 
            +
                    conv_shortcut: bool = False,
         | 
| 190 | 
            +
                    dropout: float = 0.0,
         | 
| 191 | 
            +
                    temb_channels: int = 512,
         | 
| 192 | 
            +
                    groups: int = 32,
         | 
| 193 | 
            +
                    groups_out: Optional[int] = None,
         | 
| 194 | 
            +
                    pre_norm: bool = True,
         | 
| 195 | 
            +
                    eps: float = 1e-6,
         | 
| 196 | 
            +
                    non_linearity: str = "swish",
         | 
| 197 | 
            +
                    time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
         | 
| 198 | 
            +
                    output_scale_factor: float = 1.0,
         | 
| 199 | 
            +
                    use_in_shortcut: Optional[bool] = None,
         | 
| 200 | 
            +
                    conv_shortcut_bias: bool = True,
         | 
| 201 | 
            +
                    conv_2d_out_channels: Optional[int] = None,
         | 
| 202 | 
            +
                ):
         | 
| 203 | 
            +
                    super().__init__()
         | 
| 204 | 
            +
                    self.pre_norm = pre_norm
         | 
| 205 | 
            +
                    self.pre_norm = True
         | 
| 206 | 
            +
                    self.in_channels = in_channels
         | 
| 207 | 
            +
                    out_channels = in_channels if out_channels is None else out_channels
         | 
| 208 | 
            +
                    self.out_channels = out_channels
         | 
| 209 | 
            +
                    self.use_conv_shortcut = conv_shortcut
         | 
| 210 | 
            +
                    self.output_scale_factor = output_scale_factor
         | 
| 211 | 
            +
                    self.time_embedding_norm = time_embedding_norm
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    linear_cls = nn.Linear
         | 
| 214 | 
            +
                    conv_cls = nn.Conv3d
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    if groups_out is None:
         | 
| 217 | 
            +
                        groups_out = groups
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    if self.time_embedding_norm == "ada_group":
         | 
| 220 | 
            +
                        self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
         | 
| 221 | 
            +
                    elif self.time_embedding_norm == "spatial":
         | 
| 222 | 
            +
                        self.norm1 = SpatialNorm(in_channels, temb_channels)
         | 
| 223 | 
            +
                    else:
         | 
| 224 | 
            +
                        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    if self.time_embedding_norm == "ada_group":
         | 
| 229 | 
            +
                        self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
         | 
| 230 | 
            +
                    elif self.time_embedding_norm == "spatial":
         | 
| 231 | 
            +
                        self.norm2 = SpatialNorm(out_channels, temb_channels)
         | 
| 232 | 
            +
                    else:
         | 
| 233 | 
            +
                        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    self.dropout = torch.nn.Dropout(dropout)
         | 
| 236 | 
            +
                    conv_2d_out_channels = conv_2d_out_channels or out_channels
         | 
| 237 | 
            +
                    self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    self.nonlinearity = get_activation(non_linearity)
         | 
| 240 | 
            +
                    self.upsample = self.downsample = None
         | 
| 241 | 
            +
                    self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    self.conv_shortcut = None
         | 
| 244 | 
            +
                    if self.use_in_shortcut:
         | 
| 245 | 
            +
                        self.conv_shortcut = conv_cls(
         | 
| 246 | 
            +
                            in_channels,
         | 
| 247 | 
            +
                            conv_2d_out_channels,
         | 
| 248 | 
            +
                            kernel_size=1,
         | 
| 249 | 
            +
                            stride=1,
         | 
| 250 | 
            +
                            padding=0,
         | 
| 251 | 
            +
                            bias=conv_shortcut_bias,
         | 
| 252 | 
            +
                        )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def forward(
         | 
| 255 | 
            +
                    self,
         | 
| 256 | 
            +
                    input_tensor: torch.FloatTensor,
         | 
| 257 | 
            +
                    temb: torch.FloatTensor = None,
         | 
| 258 | 
            +
                    scale: float = 1.0,
         | 
| 259 | 
            +
                ) -> torch.FloatTensor:
         | 
| 260 | 
            +
                    hidden_states = input_tensor
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
         | 
| 263 | 
            +
                        hidden_states = self.norm1(hidden_states, temb)
         | 
| 264 | 
            +
                    else:
         | 
| 265 | 
            +
                        hidden_states = self.norm1(hidden_states)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    hidden_states = self.nonlinearity(hidden_states)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    hidden_states = self.conv1(hidden_states)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    if temb is not None and self.time_embedding_norm == "default":
         | 
| 272 | 
            +
                        hidden_states = hidden_states + temb
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
         | 
| 275 | 
            +
                        hidden_states = self.norm2(hidden_states, temb)
         | 
| 276 | 
            +
                    else:
         | 
| 277 | 
            +
                        hidden_states = self.norm2(hidden_states)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    hidden_states = self.nonlinearity(hidden_states)
         | 
| 280 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 281 | 
            +
                    hidden_states = self.conv2(hidden_states)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    if self.conv_shortcut is not None:
         | 
| 284 | 
            +
                        input_tensor = self.conv_shortcut(input_tensor)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    return output_tensor
         | 
| 289 | 
            +
             | 
| 290 | 
            +
             | 
| 291 | 
            +
            class CausalDownsample2x(nn.Module):
         | 
| 292 | 
            +
                """A 2D downsampling layer with an optional convolution.
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                Parameters:
         | 
| 295 | 
            +
                    channels (`int`):
         | 
| 296 | 
            +
                        number of channels in the inputs and outputs.
         | 
| 297 | 
            +
                    use_conv (`bool`, default `False`):
         | 
| 298 | 
            +
                        option to use a convolution.
         | 
| 299 | 
            +
                    out_channels (`int`, optional):
         | 
| 300 | 
            +
                        number of output channels. Defaults to `channels`.
         | 
| 301 | 
            +
                    padding (`int`, default `1`):
         | 
| 302 | 
            +
                        padding for the convolution.
         | 
| 303 | 
            +
                    name (`str`, default `conv`):
         | 
| 304 | 
            +
                        name of the downsampling 2D layer.
         | 
| 305 | 
            +
                """
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                def __init__(
         | 
| 308 | 
            +
                    self,
         | 
| 309 | 
            +
                    channels: int,
         | 
| 310 | 
            +
                    use_conv: bool = True,
         | 
| 311 | 
            +
                    out_channels: Optional[int] = None,
         | 
| 312 | 
            +
                    name: str = "conv",
         | 
| 313 | 
            +
                    kernel_size=3,
         | 
| 314 | 
            +
                    bias=True,
         | 
| 315 | 
            +
                ):
         | 
| 316 | 
            +
                    super().__init__()
         | 
| 317 | 
            +
                    self.channels = channels
         | 
| 318 | 
            +
                    self.out_channels = out_channels or channels
         | 
| 319 | 
            +
                    self.use_conv = use_conv
         | 
| 320 | 
            +
                    stride = (1, 2, 2)
         | 
| 321 | 
            +
                    self.name = name
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    if use_conv:
         | 
| 324 | 
            +
                        conv = CausalConv3d(
         | 
| 325 | 
            +
                            self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
         | 
| 326 | 
            +
                        )
         | 
| 327 | 
            +
                    else:
         | 
| 328 | 
            +
                        assert self.channels == self.out_channels
         | 
| 329 | 
            +
                        conv = nn.AvgPool3d(kernel_size=stride, stride=stride)
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    self.conv = conv
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
         | 
| 334 | 
            +
                    assert hidden_states.shape[1] == self.channels
         | 
| 335 | 
            +
                    hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
         | 
| 336 | 
            +
                    return hidden_states
         | 
| 337 | 
            +
             | 
| 338 | 
            +
             | 
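The stride `(1, 2, 2)` used above halves only the spatial dimensions. A self-contained illustration of that stride semantics via the average-pooling fallback (plain torch, no repo dependencies):

import torch
import torch.nn as nn

pool = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
x = torch.randn(1, 8, 9, 32, 32)   # (b, c, t, h, w)
print(pool(x).shape)               # torch.Size([1, 8, 9, 16, 16]) -- t unchanged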
class Downsample2D(nn.Module):
    """A spatial 2x downsampling layer with an optional convolution, for
    `(b, c, t, h, w)` tensors. The stride is `(1, 2, 2)`; with `padding == 0`
    the input is padded asymmetrically (one extra column/row on the right and
    bottom) before the convolution, so even spatial sizes are halved exactly.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `True`):
            option to use a convolution; otherwise average pooling is used
            (which requires `channels == out_channels`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `0`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        padding: int = 0,
        name: str = "conv",
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = (1, 2, 2)
        self.name = name
        conv_cls = nn.Conv3d

        if use_conv:
            conv = conv_cls(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            assert self.channels == self.out_channels
            # pooling on 5D video tensors needs the 3D variant
            conv = nn.AvgPool3d(kernel_size=stride, stride=stride)

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back)
            pad = (0, 1, 0, 1, 1, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)

        return hidden_states

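The effect of that asymmetric padding can be checked in isolation (plain torch): one extra column/row on the right/bottom before a kernel-3, stride-(1, 2, 2) convolution halves even spatial sizes exactly while leaving t alone:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 16, 16)   # (b, c, t, h, w)
# F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back)
x = F.pad(x, (0, 1, 0, 1, 1, 1), mode="constant", value=0)
conv = nn.Conv3d(8, 8, kernel_size=3, stride=(1, 2, 2), padding=0)
print(conv(x).shape)               # torch.Size([1, 8, 5, 8, 8])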
| 398 | 
            +
class TemporalDownsample2x(nn.Module):
    """A temporal downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `0`):
            padding for the convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 0,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = (2, 1, 1)

        conv_cls = nn.Conv3d

        if use_conv:
            conv = conv_cls(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            raise NotImplementedError("Temporal downsampling without a convolution is not implemented")

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            if hidden_states.shape[2] == 1:
                # image: pad the single frame on both temporal sides
                pad = (1, 1, 1, 1, 1, 1)
            else:
                # video: pad one frame at the end of time only, so T frames map to T // 2
                pad = (1, 1, 1, 1, 0, 1)

            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)
        return hidden_states


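# Illustrative sketch: TemporalDownsample2x halves the frame count of a video
# and keeps a single-frame input at one frame. `_demo_temporal_downsample_shapes`
# is a hypothetical helper for illustration only.
def _demo_temporal_downsample_shapes():
    down = TemporalDownsample2x(channels=4, use_conv=True, padding=0)
    video = torch.randn(1, 4, 8, 16, 16)            # 8 frames
    assert down(video).shape == (1, 4, 4, 16, 16)   # time pad (0, 1): 8 -> 4
    image = torch.randn(1, 4, 1, 16, 16)            # single frame
    assert down(image).shape == (1, 4, 1, 16, 16)   # time pad (1, 1): 1 -> 1

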
class CausalTemporalDownsample2x(nn.Module):
    """A causal temporal downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = (2, 1, 1)

        if use_conv:
            # CausalConv3d applies the (past-only) temporal padding internally
            conv = CausalConv3d(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
            )
        else:
            raise NotImplementedError("Temporal downsampling without a convolution is not implemented")

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        return hidden_states


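# Illustrative sketch of the causal padding idea behind CausalConv3d (defined in
# video_vae/modeling_causal_conv.py): pad only *past* frames, so no output frame
# depends on future input. The replicate mode and exact amounts here are
# assumptions for illustration; the real layer may differ.
def _demo_causal_temporal_pad():
    x = torch.randn(1, 4, 9, 8, 8)  # (B, C, T, H, W), e.g. 1 init frame + 8 more
    # pad k - 1 = 2 frames at the start of time, none at the end; H/W get 1 each side
    x = F.pad(x, (1, 1, 1, 1, 2, 0), mode="replicate")
    conv = nn.Conv3d(4, 4, kernel_size=3, stride=(2, 1, 1), padding=0)
    assert conv(x).shape == (1, 4, 5, 8, 8)  # T: (9 + 2 - 3) // 2 + 1 = 5

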
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate
        conv_cls = nn.Conv3d

        if interpolate:
            raise NotImplementedError("Spatial upsampling with interpolation is not implemented")
        else:
            # depth-to-space: predict 4x channels, then shuffle them into a 2x2 spatial grid
            if kernel_size is None:
                kernel_size = 3
            conv = conv_cls(self.channels, self.out_channels * 4, kernel_size=kernel_size, padding=padding, bias=bias)

        self.conv = conv
        self.conv.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states)
        hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)

        return hidden_states


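# Illustrative sketch: the rearrange in Upsample2D.forward is a pixel-shuffle
# (depth-to-space) -- 4x channels trade for a 2x2 spatial grid. The same shape
# behavior holds for CausalUpsample2x below. `_demo_spatial_depth_to_space` is a
# hypothetical helper for illustration only.
def _demo_spatial_depth_to_space():
    up = Upsample2D(channels=4, out_channels=4)
    x = torch.randn(1, 4, 5, 8, 8)
    assert up(x).shape == (1, 4, 5, 16, 16)  # channels preserved, H and W doubled

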
class CausalUpsample2x(nn.Module):
    """A causal 2D (spatial) upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate

        if interpolate:
            raise NotImplementedError("Spatial upsampling with interpolation is not implemented")
        else:
            # depth-to-space with a causal convolution in time
            conv = CausalConv3d(self.channels, self.out_channels * 4, kernel_size=kernel_size, stride=1, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True, temporal_chunk=False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
        return hidden_states


class TemporalUpsample2x(nn.Module):
    """A temporal upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `True`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate
        conv_cls = nn.Conv3d

        if interpolate:
            raise NotImplementedError("Temporal upsampling with interpolation is not implemented")
        else:
            # depth-to-space along time: predict 2x channels, then shuffle into 2x frames
            if kernel_size is None:
                kernel_size = 3
            conv = conv_cls(self.channels, self.out_channels * 2, kernel_size=kernel_size, padding=padding, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_image: bool = False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        t = hidden_states.shape[2]
        hidden_states = self.conv(hidden_states)
        # '(t p)' keeps the two predicted frames per input frame adjacent in time,
        # matching CausalTemporalUpsample2x below; '(p t)' would scramble the
        # temporal order for t > 1
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)

        if t == 1 and is_image:
            # a single frame is duplicated by the shuffle; keep only one copy
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states


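# Illustrative sketch: temporal depth-to-space doubles the frame count; with a
# single input frame and is_image=True the duplicated frame is dropped again.
# `_demo_temporal_upsample_shapes` is a hypothetical helper for illustration only.
def _demo_temporal_upsample_shapes():
    up = TemporalUpsample2x(channels=4)
    video = torch.randn(1, 4, 4, 8, 8)
    assert up(video).shape == (1, 4, 8, 8, 8)                 # 4 frames -> 8 frames
    image = torch.randn(1, 4, 1, 8, 8)
    assert up(image, is_image=True).shape == (1, 4, 1, 8, 8)  # image stays 1 frame

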
class CausalTemporalUpsample2x(nn.Module):
    """A causal temporal upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `True`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate

        if interpolate:
            raise NotImplementedError("Temporal upsampling with interpolation is not implemented")
        else:
            # depth-to-space along time with a causal convolution
            conv = CausalConv3d(self.channels, self.out_channels * 2, kernel_size=kernel_size, stride=1, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True, temporal_chunk=False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        t = hidden_states.shape[2]
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)

        if is_init_image:
            # the first latent frame decodes to a single frame: drop its duplicate,
            # so t latent frames yield 2 * t - 1 output frames
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states


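# Illustrative sketch: frame-count arithmetic of the causal temporal upsample,
# using only the rearrange and the duplicate-drop from forward() above.
# `_demo_causal_temporal_upsample_count` is a hypothetical helper.
def _demo_causal_temporal_upsample_count():
    t = 5
    x = torch.randn(1, 8, t, 4, 4)  # 8 = out_channels * 2 for out_channels = 4
    y = rearrange(x, 'b (c p) t h w -> b c (t p) h w', p=2)
    assert y.shape == (1, 4, 2 * t, 4, 4)
    y = y[:, :, 1:]                  # drop duplicated first frame (is_init_image)
    assert y.shape[2] == 2 * t - 1   # 5 latent frames -> 9 output frames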