hanshu.yan committed on
Commit 2ec72fb · Parent(s): b83e3cf

add app.py

Files changed:
 - LICENSE +21 -0
 - README.md +80 -12
 - app.py +182 -0
 - gradio_app.py +188 -0
 - output/.DS_Store +0 -0
 - output/0/input.png +0 -0
 - output/0/mesh.obj +0 -0
 - requirements.txt +17 -0
 - requirements2.txt +9 -0
 - run.py +162 -0
 - src/__pycache__/__init__.cpython-38.pyc +0 -0
 - src/__pycache__/scheduler_perflow.cpython-310.pyc +0 -0
 - src/__pycache__/scheduler_perflow.cpython-38.pyc +0 -0
 - src/__pycache__/utils_perflow.cpython-38.pyc +0 -0
 - src/laion_bytenas.py +257 -0
 - src/pfode_solver.py +120 -0
 - src/scheduler_perflow.py +343 -0
 - src/utils_perflow.py +77 -0
 - test.yaml +10 -0
 - tsr/__pycache__/system.cpython-310.pyc +0 -0
 - tsr/__pycache__/system.cpython-38.pyc +0 -0
 - tsr/__pycache__/utils.cpython-310.pyc +0 -0
 - tsr/__pycache__/utils.cpython-38.pyc +0 -0
 - tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
 - tsr/models/__pycache__/isosurface.cpython-38.pyc +0 -0
 - tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
 - tsr/models/__pycache__/nerf_renderer.cpython-38.pyc +0 -0
 - tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
 - tsr/models/__pycache__/network_utils.cpython-38.pyc +0 -0
 - tsr/models/isosurface.py +52 -0
 - tsr/models/nerf_renderer.py +180 -0
 - tsr/models/network_utils.py +124 -0
 - tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
 - tsr/models/tokenizers/__pycache__/image.cpython-38.pyc +0 -0
 - tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
 - tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc +0 -0
 - tsr/models/tokenizers/image.py +66 -0
 - tsr/models/tokenizers/triplane.py +45 -0
 - tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
 - tsr/models/transformer/__pycache__/attention.cpython-38.pyc +0 -0
 - tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
 - tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc +0 -0
 - tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
 - tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc +0 -0
 - tsr/models/transformer/attention.py +653 -0
 - tsr/models/transformer/basic_transformer_block.py +334 -0
 - tsr/models/transformer/transformer_1d.py +219 -0
 - tsr/system.py +203 -0
 - tsr/utils.py +474 -0
 
    	
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Tripo AI & Stability AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
README.md CHANGED
@@ -1,12 +1,80 @@
-
-
-
-
-
-
-
-
-
-
-
-
+# TripoSR <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <a href="https://huggingface.co/spaces/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/Arxiv-2403.02151-B31B1B.svg"></a>
+
+<div align="center">
+  <img src="figures/teaser800.gif" alt="Teaser Video">
+</div>
+
+This is the official codebase for **TripoSR**, a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
+<br><br>
+Leveraging the principles of the [Large Reconstruction Model (LRM)](https://yiconghong.me/LRM/), TripoSR brings key advancements that significantly boost both the speed and quality of 3D reconstruction. Our model is distinguished by its ability to rapidly process inputs, generating high-quality 3D models in less than 0.5 seconds on an NVIDIA A100 GPU. TripoSR has exhibited superior performance in both qualitative and quantitative evaluations, outperforming other open-source alternatives across multiple public datasets. The figures below illustrate visual comparisons and metrics showcasing TripoSR's performance relative to other leading models. Details about the model architecture, training process, and comparisons can be found in this [technical report](https://arxiv.org/abs/2403.02151).
+
+<!--
+<div align="center">
+  <img src="figures/comparison800.gif" alt="Teaser Video">
+</div>
+-->
+<p align="center">
+    <img width="800" src="figures/visual_comparisons.jpg"/>
+</p>
+
+<p align="center">
+    <img width="450" src="figures/scatter-comparison.png"/>
+</p>
+
+
+The model is released under the MIT license, which includes the source code, pretrained models, and an interactive online demo. Our goal is to empower researchers, developers, and creatives to push the boundaries of what's possible in 3D generative AI and 3D content creation.
+
+## Getting Started
+### Installation
+- Python >= 3.8
+- Install CUDA if available
+- Install PyTorch according to your platform: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) **[Please make sure that the locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example, if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.]**
+- Update setuptools by `pip install --upgrade setuptools`
+- Install other dependencies by `pip install -r requirements.txt`
+
+### Manual Inference
+```sh
+python run.py examples/chair.png --output-dir output/
+```
+This will save the reconstructed 3D model to `output/`. You can also specify more than one image path separated by spaces. The default options take about **6GB VRAM** for a single image input.
+
+For detailed usage of this script, use `python run.py --help`.
+
+### Local Gradio App
+Install Gradio:
+```sh
+pip install gradio
+```
+Start the Gradio App:
+```sh
+python gradio_app.py
+```
+
+## Troubleshooting
+> AttributeError: module 'torchmcubes_module' has no attribute 'mcubes_cuda'
+
+or
+
+> torchmcubes was not compiled with CUDA support, use CPU version instead.
+
+This is because `torchmcubes` is compiled without CUDA support. Please make sure that
+
+- The locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example, if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.
+- `setuptools>=49.6.0`. If not, upgrade by `pip install --upgrade setuptools`.
+
+Then re-install `torchmcubes` by:
+
+```sh
+pip uninstall torchmcubes
+pip install git+https://github.com/tatsy/torchmcubes.git
+```
+
+## Citation
+```BibTeX
+@article{TripoSR2024,
+  title={TripoSR: Fast 3D Object Reconstruction from a Single Image},
+  author={Tochilkin, Dmitry and Pankratz, David and Liu, Zexiang and Huang, Zixuan and Letts, Adam and Li, Yangguang and Liang, Ding and Laforte, Christian and Jampani, Varun and Cao, Yan-Pei},
+  journal={arXiv preprint arXiv:2403.02151},
+  year={2024}
+}
+```
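The CUDA note in the installation and troubleshooting sections above is the usual failure point. As a minimal sketch only (assuming a local CUDA 11.8 toolkit; take the exact command for your platform from the selector at pytorch.org), the install order would be:

```sh
# Sketch assuming CUDA 11.8 is installed locally; pick the wheel index that
# matches your CUDA major version at https://pytorch.org/get-started/locally/
pip install --upgrade setuptools
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
```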
    	
app.py ADDED
@@ -0,0 +1,182 @@
+import os, logging, time, argparse, random, tempfile, rembg
+import gradio as gr
+import numpy as np
+import torch
+from PIL import Image
+from functools import partial
+from tsr.system import TSR
+from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
+
+from src.scheduler_perflow import PeRFlowScheduler
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+
+def merge_delta_weights_into_unet(pipe, delta_weights, org_alpha = 1.0):
+    unet_weights = pipe.unet.state_dict()
+    for key in delta_weights.keys():
+        dtype = unet_weights[key].dtype
+        try:
+            unet_weights[key] = org_alpha * unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
+        except:
+            unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype)
+        unet_weights[key] = unet_weights[key].to(dtype)
+    pipe.unet.load_state_dict(unet_weights, strict=True)
+    return pipe
+
+def setup_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+
+if torch.cuda.is_available():
+    device = "cuda:0"
+else:
+    device = "cpu"
+
+### TripoSR
+model = TSR.from_pretrained(
+    "stabilityai/TripoSR",
+    config_name="config.yaml",
+    weight_name="model.ckpt",
+)
+# adjust the chunk size to balance between speed and memory usage
+model.renderer.set_chunk_size(8192)
+model.to(device)
+
+
+### PeRFlow-T2I
+# pipe_t2i = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-8", torch_dtype=torch.float16, safety_checker=None)
+pipe_t2i = StableDiffusionPipeline.from_pretrained("stablediffusionapi/disney-pixar-cartoon", torch_dtype=torch.float16, safety_checker=None)
+delta_weights = UNet2DConditionModel.from_pretrained("hansyan/piecewise-rectified-flow-delta-weights", torch_dtype=torch.float16, variant="v0-1",).state_dict()
+pipe_t2i = merge_delta_weights_into_unet(pipe_t2i, delta_weights)
+pipe_t2i.scheduler = PeRFlowScheduler.from_config(pipe_t2i.scheduler.config, prediction_type="epsilon", num_time_windows=4)
+pipe_t2i.to('cuda:0', torch.float16)
+
+
+### gradio
+rembg_session = rembg.new_session()
+
+def generate(text, seed):
+    def fill_background(image):
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+        image = Image.fromarray((image * 255.0).astype(np.uint8))
+        return image
+
+    setup_seed(int(seed))
+    # text = text
+    samples = pipe_t2i(
+            prompt              = [text],
+            negative_prompt     = ["distorted, blur, low-quality, haze, out of focus"],
+            height              = 512,
+            width               = 512,
+            # num_inference_steps = 4,
+            # guidance_scale      = 4.5,
+            num_inference_steps = 6,
+            guidance_scale      = 7,
+            output_type         = 'pt',
+        ).images
+    samples = torch.nn.functional.interpolate(samples, size=768, mode='bilinear')
+    samples = samples.squeeze(0).permute(1, 2, 0).cpu().numpy()*255.
+    samples = samples.astype(np.uint8)
+    samples = Image.fromarray(samples[:, :, :3])
+
+    image = remove_background(samples, rembg_session)
+    image = resize_foreground(image, 0.85)
+    image = fill_background(image)
+    return image
+
+def render(image, mc_resolution=256, formats=["obj"]):
+    scene_codes = model(image, device=device)
+    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
+    mesh = to_gradio_3d_orientation(mesh)
+    rv = []
+    for format in formats:
+        mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
+        mesh.export(mesh_path.name)
+        rv.append(mesh_path.name)
+    return rv[0]
+
+# warm up
+_ = generate("a bird", 42)
+
+# layout
+css = """
+h1 {
+    text-align: center;
+    display:block;
+}
+h2 {
+    text-align: center;
+    display:block;
+}
+h3 {
+    text-align: center;
+    display:block;
+}
+"""
+with gr.Blocks(title="TripoSR", css=css) as interface:
+    gr.Markdown(
+    """
+    # Instant Text-to-3D Mesh Demo
+
+    ### [PeRFlow](https://github.com/magic-research/piecewise-rectified-flow)-T2I  +  [TripoSR](https://github.com/VAST-AI-Research/TripoSR)
+
+    Two-stage synthesis: 1) generating images by PeRFlow-T2I with 6-step inference; 2) rendering 3D assets.
+    """
+    )
+
+    with gr.Column():
+        with gr.Row():
+                output_image = gr.Image(label='Generated Image', height=384, width=384)
+
+                output_model_obj = gr.Model3D(
+                    label="Output 3D Model (OBJ Format)",
+                    interactive=False,
+                    height=384, width=384,
+            )
+
+    with gr.Row():
+        textbox = gr.Textbox(label="Input Prompt", value="a colorful bird")
+        seed = gr.Textbox(label="Random Seed", value=42)
+
+    # activate
+    textbox.submit(
+        fn=generate,
+        inputs=[textbox, seed],
+        outputs=[output_image],
+    ).success(
+        fn=render,
+        inputs=[output_image],
+        outputs=[output_model_obj],
+    )
+
+    seed.submit(
+        fn=generate,
+        inputs=[textbox, seed],
+        outputs=[output_image],
+    ).success(
+        fn=render,
+        inputs=[output_image],
+        outputs=[output_model_obj],
+    )
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--username', type=str, default=None, help='Username for authentication')
+    parser.add_argument('--password', type=str, default=None, help='Password for authentication')
+    parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
+    parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+    parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
+    parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
+    args = parser.parse_args()
+    interface.queue(max_size=args.queuesize)
+    interface.launch(
+        auth=(args.username, args.password) if (args.username and args.password) else None,
+        share=args.share,
+        server_name="0.0.0.0" if args.listen else None,
+        server_port=args.port
+    )
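The argparse block at the end of app.py defines the launch options; a usage sketch combining them (the values shown are example choices):

```sh
# Sketch: serve the text-to-3D demo on the local network on port 7860
# with a Gradio request queue of size 4 (flags from app.py's argparse block)
python app.py --listen --port 7860 --queuesize 4
```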
    	
        gradio_app.py
    ADDED
    
    | 
         @@ -0,0 +1,188 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import logging
         
     | 
| 2 | 
         
            +
            import os
         
     | 
| 3 | 
         
            +
            import tempfile
         
     | 
| 4 | 
         
            +
            import time
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import gradio as gr
         
     | 
| 7 | 
         
            +
            import numpy as np
         
     | 
| 8 | 
         
            +
            import rembg
         
     | 
| 9 | 
         
            +
            import torch
         
     | 
| 10 | 
         
            +
            from PIL import Image
         
     | 
| 11 | 
         
            +
            from functools import partial
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            from tsr.system import TSR
         
     | 
| 14 | 
         
            +
            from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            import argparse
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            if torch.cuda.is_available():
         
     | 
| 20 | 
         
            +
                device = "cuda:0"
         
     | 
| 21 | 
         
            +
            else:
         
     | 
| 22 | 
         
            +
                device = "cpu"
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            model = TSR.from_pretrained(
         
     | 
| 25 | 
         
            +
                "stabilityai/TripoSR",
         
     | 
| 26 | 
         
            +
                config_name="config.yaml",
         
     | 
| 27 | 
         
            +
                weight_name="model.ckpt",
         
     | 
| 28 | 
         
            +
            )
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            # adjust the chunk size to balance between speed and memory usage
         
     | 
| 31 | 
         
            +
            model.renderer.set_chunk_size(8192)
         
     | 
| 32 | 
         
            +
            model.to(device)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            rembg_session = rembg.new_session()
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            def check_input_image(input_image):
         
     | 
| 38 | 
         
            +
                if input_image is None:
         
     | 
| 39 | 
         
            +
                    raise gr.Error("No image uploaded!")
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
            def preprocess(input_image, do_remove_background, foreground_ratio):
         
     | 
| 43 | 
         
            +
                def fill_background(image):
         
     | 
| 44 | 
         
            +
                    image = np.array(image).astype(np.float32) / 255.0
         
     | 
| 45 | 
         
            +
                    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
         
     | 
| 46 | 
         
            +
                    image = Image.fromarray((image * 255.0).astype(np.uint8))
         
     | 
| 47 | 
         
            +
                    return image
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                if do_remove_background:
         
     | 
| 50 | 
         
            +
                    image = input_image.convert("RGB")
         
     | 
| 51 | 
         
            +
                    image = remove_background(image, rembg_session)
         
     | 
| 52 | 
         
            +
                    image = resize_foreground(image, foreground_ratio)
         
     | 
| 53 | 
         
            +
                    image = fill_background(image)
         
     | 
| 54 | 
         
            +
                else:
         
     | 
| 55 | 
         
            +
                    image = input_image
         
     | 
| 56 | 
         
            +
                    if image.mode == "RGBA":
         
     | 
| 57 | 
         
            +
                        image = fill_background(image)
         
     | 
| 58 | 
         
            +
                return image
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            def generate(image, mc_resolution, formats=["obj", "glb"]):
         
     | 
| 62 | 
         
            +
                print(image.shape, image.min(), image.max())
         
     | 
| 63 | 
         
            +
                scene_codes = model(image, device=device)
         
     | 
| 64 | 
         
            +
                mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
         
     | 
| 65 | 
         
            +
                mesh = to_gradio_3d_orientation(mesh)
         
     | 
| 66 | 
         
            +
                rv = []
         
     | 
| 67 | 
         
            +
                for format in formats:
         
     | 
| 68 | 
         
            +
                    mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
         
     | 
| 69 | 
         
            +
                    mesh.export(mesh_path.name)
         
     | 
| 70 | 
         
            +
                    rv.append(mesh_path.name)
         
     | 
| 71 | 
         
            +
                return rv
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
            def run_example(image_pil):
         
     | 
| 75 | 
         
            +
                preprocessed = preprocess(image_pil, False, 0.9)
         
     | 
| 76 | 
         
            +
                mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
         
     | 
| 77 | 
         
            +
                return preprocessed, mesh_name_obj, mesh_name_glb
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            with gr.Blocks(title="TripoSR") as interface:
         
     | 
| 81 | 
         
            +
                gr.Markdown(
         
     | 
| 82 | 
         
            +
                    """
         
     | 
| 83 | 
         
            +
                # TripoSR Demo
         
     | 
| 84 | 
         
            +
                [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
         
     | 
| 85 | 
         
            +
                
         
     | 
| 86 | 
         
            +
                **Tips:**
         
     | 
| 87 | 
         
            +
                1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
         
     | 
| 88 | 
         
            +
                2. You can disable "Remove Background" for the provided examples since they have been already preprocessed.
         
     | 
| 89 | 
         
            +
                3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
         
     | 
| 90 | 
         
            +
                """
         
     | 
| 91 | 
         
            +
                )
         
     | 
| 92 | 
         
            +
                with gr.Row(variant="panel"):
         
     | 
| 93 | 
         
            +
                    with gr.Column():
         
     | 
| 94 | 
         
            +
                        with gr.Row():
         
     | 
| 95 | 
         
            +
                            input_image = gr.Image(
         
     | 
| 96 | 
         
            +
                                label="Input Image",
         
     | 
| 97 | 
         
            +
                                image_mode="RGBA",
         
     | 
| 98 | 
         
            +
                                sources="upload",
         
     | 
| 99 | 
         
            +
                                type="pil",
         
     | 
| 100 | 
         
            +
                                elem_id="content_image",
         
     | 
| 101 | 
         
            +
                            )
         
     | 
| 102 | 
         
            +
                            processed_image = gr.Image(label="Processed Image", interactive=False)
         
     | 
| 103 | 
         
            +
                        with gr.Row():
         
     | 
| 104 | 
         
            +
                            with gr.Group():
         
     | 
| 105 | 
         
            +
                                do_remove_background = gr.Checkbox(
         
     | 
| 106 | 
         
            +
                                    label="Remove Background", value=True
         
     | 
| 107 | 
         
            +
                                )
         
     | 
| 108 | 
         
            +
                                foreground_ratio = gr.Slider(
         
     | 
| 109 | 
         
            +
                                    label="Foreground Ratio",
         
     | 
| 110 | 
         
            +
                                    minimum=0.5,
         
     | 
| 111 | 
         
            +
                                    maximum=1.0,
         
     | 
| 112 | 
         
            +
                                    value=0.85,
         
     | 
| 113 | 
         
            +
                                    step=0.05,
         
     | 
| 114 | 
         
            +
                                )
         
     | 
| 115 | 
         
            +
                                mc_resolution = gr.Slider(
         
     | 
| 116 | 
         
            +
                                    label="Marching Cubes Resolution",
         
     | 
| 117 | 
         
            +
                                    minimum=32,
         
     | 
| 118 | 
         
            +
                                    maximum=320,
         
     | 
| 119 | 
         
            +
                                    value=256,
         
     | 
| 120 | 
         
            +
                                    step=32
         
     | 
| 121 | 
         
            +
                                )
         
     | 
| 122 | 
         
            +
                        with gr.Row():
         
     | 
| 123 | 
         
            +
                            submit = gr.Button("Generate", elem_id="generate", variant="primary")
         
     | 
| 124 | 
         
            +
                    with gr.Column():
         
     | 
| 125 | 
         
            +
                        with gr.Tab("OBJ"):
         
     | 
| 126 | 
         
            +
                            output_model_obj = gr.Model3D(
         
     | 
| 127 | 
         
            +
                                label="Output Model (OBJ Format)",
         
     | 
| 128 | 
         
            +
                                interactive=False,
         
     | 
| 129 | 
         
            +
                            )
         
     | 
| 130 | 
         
            +
                            gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
         
     | 
| 131 | 
         
            +
                        with gr.Tab("GLB"):
         
     | 
| 132 | 
         
            +
                            output_model_glb = gr.Model3D(
         
     | 
| 133 | 
         
            +
                                label="Output Model (GLB Format)",
         
     | 
| 134 | 
         
            +
                                interactive=False,
         
     | 
| 135 | 
         
            +
                            )
         
     | 
| 136 | 
         
            +
                            gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
         
     | 
| 137 | 
         
            +
                with gr.Row(variant="panel"):
         
     | 
| 138 | 
         
            +
                    gr.Examples(
         
     | 
| 139 | 
         
            +
                        examples=[
         
     | 
| 140 | 
         
            +
                            "examples/hamburger.png",
         
     | 
| 141 | 
         
            +
                            "examples/poly_fox.png",
         
     | 
| 142 | 
         
            +
                            "examples/robot.png",
         
     | 
| 143 | 
         
            +
                            "examples/teapot.png",
         
     | 
| 144 | 
         
            +
                            "examples/tiger_girl.png",
         
     | 
| 145 | 
         
            +
                            "examples/horse.png",
         
     | 
| 146 | 
         
            +
                            "examples/flamingo.png",
         
     | 
| 147 | 
         
            +
                            "examples/unicorn.png",
         
     | 
| 148 | 
         
            +
                            "examples/chair.png",
         
     | 
| 149 | 
         
            +
                            "examples/iso_house.png",
         
     | 
| 150 | 
         
            +
                            "examples/marble.png",
         
     | 
| 151 | 
         
            +
                            "examples/police_woman.png",
         
     | 
| 152 | 
         
            +
                            "examples/captured_p.png",
         
     | 
| 153 | 
         
            +
                        ],
         
     | 
| 154 | 
         
            +
                        inputs=[input_image],
         
     | 
| 155 | 
         
            +
                        outputs=[processed_image, output_model_obj, output_model_glb],
         
     | 
| 156 | 
         
            +
                        cache_examples=False,
         
     | 
| 157 | 
         
            +
                        fn=partial(run_example),
         
     | 
| 158 | 
         
            +
                        label="Examples",
         
     | 
| 159 | 
         
            +
                        examples_per_page=20,
         
     | 
| 160 | 
         
            +
                    )
         
     | 
| 161 | 
         
            +
                submit.click(fn=check_input_image, inputs=[input_image]).success(
         
     | 
| 162 | 
         
            +
                    fn=preprocess,
         
     | 
| 163 | 
         
            +
                    inputs=[input_image, do_remove_background, foreground_ratio],
         
     | 
| 164 | 
         
            +
                    outputs=[processed_image],
         
     | 
| 165 | 
         
            +
                ).success(
         
     | 
| 166 | 
         
            +
                    fn=generate,
         
     | 
| 167 | 
         
            +
                    inputs=[processed_image, mc_resolution],
         
     | 
| 168 | 
         
            +
                    outputs=[output_model_obj, output_model_glb],
         
     | 
| 169 | 
         
            +
                )
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 174 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 175 | 
         
            +
                parser.add_argument('--username', type=str, default=None, help='Username for authentication')
         
     | 
| 176 | 
         
            +
                parser.add_argument('--password', type=str, default=None, help='Password for authentication')
         
     | 
| 177 | 
         
            +
                parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
         
     | 
| 178 | 
         
            +
                parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
         
     | 
| 179 | 
         
            +
                parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
         
     | 
| 180 | 
         
            +
                parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
         
     | 
| 181 | 
         
            +
                args = parser.parse_args()
         
     | 
| 182 | 
         
            +
                interface.queue(max_size=args.queuesize)
         
     | 
| 183 | 
         
            +
                interface.launch(
         
     | 
| 184 | 
         
            +
                    auth=(args.username, args.password) if (args.username and args.password) else None,
         
     | 
| 185 | 
         
            +
                    share=args.share,
         
     | 
| 186 | 
         
            +
                    server_name="0.0.0.0" if args.listen else None, 
         
     | 
| 187 | 
         
            +
                    server_port=args.port
         
     | 
| 188 | 
         
            +
                )
         
     | 
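
The block above finishes the UI definition: it registers the example gallery, chains the submit events (check_input_image -> preprocess -> generate), and launches the interface with queueing, optional basic auth, sharing, and network binding taken from the argparse flags. As a usage note, the same queue/launch options can be exercised on a stand-alone Blocks app; the sketch below is illustrative only (placeholder UI contents, hard-coded values mirroring the flag defaults) and is not part of this commit.

# Minimal sketch of the queue/launch options used above (assumptions: placeholder
# UI contents; values mirror the argparse defaults such as port 7860).
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

demo.queue(max_size=1)                # --queuesize
demo.launch(
    auth=None,                        # or ("user", "pass"), as with --username/--password
    share=False,                      # --share
    server_name="0.0.0.0",            # --listen
    server_port=7860,                 # --port
)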
    	
output/.DS_Store
ADDED
Binary file (6.15 kB)
    	
output/0/input.png
ADDED
    	
output/0/mesh.obj
ADDED
(diff too large to render)
    	
requirements.txt
ADDED
@@ -0,0 +1,17 @@
diffusers==0.24.0
einops==0.7.0
gradio==4.20.1
huggingface_hub==0.21.4
imageio==2.27.0
numpy==1.24.3
omegaconf==2.3.0
packaging==23.2
Pillow==10.1.0
rembg==2.0.55
safetensors==0.3.2
torch==2.0.0
torchvision==0.15.1
tqdm==4.64.1
transformers==4.27.0
trimesh==4.0.5
git+https://github.com/tatsy/torchmcubes.git
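
These pins are typically installed with pip (for example, pip install -r requirements.txt); note that torchmcubes is pulled directly from GitHub, so git and a C++/CUDA build toolchain must be available. A quick post-install sanity check might look like the sketch below (illustrative only; it just confirms the imports resolve).

# Illustrative sanity check that the pinned packages import correctly.
import torch
import torchvision
import gradio
import diffusers
import trimesh
import rembg
import torchmcubes  # built from the git dependency above

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("gradio", gradio.__version__, "| diffusers", diffusers.__version__)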
    	
requirements2.txt
ADDED
@@ -0,0 +1,9 @@
omegaconf==2.3.0
Pillow==10.1.0
einops==0.7.0
git+https://github.com/tatsy/torchmcubes.git
transformers==4.35.0
trimesh==4.0.5
rembg
huggingface-hub
imageio[ffmpeg]
    	
run.py
ADDED
@@ -0,0 +1,162 @@
import argparse
import logging
import os
import time

import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, save_video


class Timer:
    def __init__(self):
        self.items = {}
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"

    def start(self, name: str) -> None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.items[name] = time.time()
        logging.info(f"{name} ...")

    def end(self, name: str) -> float:
        if name not in self.items:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        t = delta * self.time_scale
        logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")


timer = Timer()


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
parser = argparse.ArgumentParser()
parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
parser.add_argument(
    "--device",
    default="cuda:0",
    type=str,
    help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
)
parser.add_argument(
    "--pretrained-model-name-or-path",
    default="stabilityai/TripoSR",
    type=str,
    help="Path to the pretrained model. Could be either a huggingface model id or a local path. Default: 'stabilityai/TripoSR'",
)
parser.add_argument(
    "--chunk-size",
    default=8192,
    type=int,
    help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
)
parser.add_argument(
    "--mc-resolution",
    default=256,
    type=int,
    help="Marching cubes grid resolution. Default: 256"
)
parser.add_argument(
    "--no-remove-bg",
    action="store_true",
    help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
)
parser.add_argument(
    "--foreground-ratio",
    default=0.85,
    type=float,
    help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
)
parser.add_argument(
    "--output-dir",
    default="output/",
    type=str,
    help="Output directory to save the results. Default: 'output/'",
)
parser.add_argument(
    "--model-save-format",
    default="obj",
    type=str,
    choices=["obj", "glb"],
    help="Format to save the extracted mesh. Default: 'obj'",
)
parser.add_argument(
    "--render",
    action="store_true",
    help="If specified, save a NeRF-rendered video. Default: false",
)
args = parser.parse_args()

output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

device = args.device
if not torch.cuda.is_available():
    device = "cpu"

timer.start("Initializing model")
model = TSR.from_pretrained(
    args.pretrained_model_name_or_path,
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(args.chunk_size)
model.to(device)
timer.end("Initializing model")

timer.start("Processing images")
images = []

if args.no_remove_bg:
    rembg_session = None
else:
    rembg_session = rembg.new_session()

for i, image_path in enumerate(args.image):
    if args.no_remove_bg:
        image = np.array(Image.open(image_path).convert("RGB"))
    else:
        image = remove_background(Image.open(image_path), rembg_session)
        image = resize_foreground(image, args.foreground_ratio)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        if not os.path.exists(os.path.join(output_dir, str(i))):
            os.makedirs(os.path.join(output_dir, str(i)))
        image.save(os.path.join(output_dir, str(i), f"input.png"))
    images.append(image)
timer.end("Processing images")

for i, image in enumerate(images):
    logging.info(f"Running image {i + 1}/{len(images)} ...")

    timer.start("Running model")
    with torch.no_grad():
        scene_codes = model([image], device=device)
    timer.end("Running model")

    if args.render:
        timer.start("Rendering")
        render_images = model.render(scene_codes, n_views=30, return_type="pil")
        for ri, render_image in enumerate(render_images[0]):
            render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
        save_video(
            render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
        )
        timer.end("Rendering")

    timer.start("Exporting mesh")
    meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
    meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
    timer.end("Exporting mesh")
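
run.py is a thin CLI around the TSR pipeline: load the model, optionally remove the background with rembg, run the network, optionally render a turntable video, and export a mesh. The same steps can be driven programmatically; the sketch below mirrors the no-background-removal branch for a single image (paths, the example file, and the hard-coded settings are placeholders, not part of this commit).

# Minimal programmatic equivalent of run.py for one already-prepared image (a sketch).
import os

import numpy as np
import torch
from PIL import Image

from tsr.system import TSR

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = TSR.from_pretrained(
    "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
)
model.renderer.set_chunk_size(8192)   # same default as --chunk-size
model.to(device)

# As with --no-remove-bg: an RGB image with a gray background and sized foreground.
image = np.array(Image.open("examples/chair.png").convert("RGB"))
with torch.no_grad():
    scene_codes = model([image], device=device)

os.makedirs("output", exist_ok=True)
meshes = model.extract_mesh(scene_codes, resolution=256)  # --mc-resolution default
meshes[0].export("output/mesh.obj")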
    	
src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (147 Bytes)

src/__pycache__/scheduler_perflow.cpython-310.pyc
ADDED
Binary file (12.2 kB)

src/__pycache__/scheduler_perflow.cpython-38.pyc
ADDED
Binary file (12.1 kB)

src/__pycache__/utils_perflow.cpython-38.pyc
ADDED
Binary file (2.64 kB)
    	
src/laion_bytenas.py
ADDED
@@ -0,0 +1,257 @@
import os
import json
import random
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageStat
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset, get_worker_info
from torchvision import transforms as T


### >>>>>>>> >>>>>>>> text related >>>>>>>> >>>>>>>> ###

class TokenizerWrapper():
    def __init__(self, tokenizer, is_train, proportion_empty_prompts, use_generic_prompts=False):
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.proportion_empty_prompts = proportion_empty_prompts
        self.use_generic_prompts = use_generic_prompts

    def __call__(self, prompts):
        if isinstance(prompts, str):
            prompts = [prompts]
        captions = []
        for caption in prompts:
            if random.random() < self.proportion_empty_prompts:
                captions.append("")
            else:
                if self.use_generic_prompts:
                    captions.append("best quality, high quality")
                elif isinstance(caption, str):
                    captions.append(caption)
                elif isinstance(caption, (list, np.ndarray)):
                    # take a random caption if there are multiple
                    captions.append(random.choice(caption) if self.is_train else caption[0])
                else:
                    raise ValueError(
                        "Caption column should contain either strings or lists of strings."
                    )
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length",
            truncation=True, return_tensors="pt"
        )
        return inputs.input_ids


### >>>>>>>> >>>>>>>> image related >>>>>>>> >>>>>>>> ###

MONOCHROMATIC_MAX_VARIANCE = 0.3

def is_monochromatic_image(pil_img):
    v = ImageStat.Stat(pil_img.convert('RGB')).var
    return sum(v) < MONOCHROMATIC_MAX_VARIANCE

def isnumeric(text):
    return (''.join(filter(str.isalnum, text))).isnumeric()


class TextPromptDataset(IterableDataset):
    '''
    The dataset for (text embedding, noise, generated latent) triplets.
    '''
    def __init__(self,
                 data_root,
                 tokenizer=None,
                 transform=None,
                 rank=0,
                 world_size=1,
                 shuffle=True,
    ):
        self.tokenizer = tokenizer
        self.transform = transform

        self.img_root = os.path.join(data_root, 'JPEGImages')
        self.data_list = []

        print("#### Loading filename list...")
        json_root = os.path.join(data_root, 'list')
        json_list = [p for p in os.listdir(json_root) if p.startswith("shard") and p.endswith('.json')]

        # duplicate several shards to make sure each process has the same number of shards
        assert len(json_list) > world_size
        duplicate = world_size - len(json_list) % world_size if len(json_list) % world_size > 0 else 0
        json_list = json_list + json_list[:duplicate]
        json_list = json_list[rank::world_size]

        for json_file in tqdm(json_list):
            shard_name = os.path.basename(json_file).split('.')[0]
            with open(os.path.join(json_root, json_file)) as f:
                key_text_pairs = json.load(f)

            for pair in key_text_pairs:
                self.data_list.append([shard_name] + pair)

        print("#### All filename loaded...")

        self.shuffle = shuffle

    def __len__(self):
        return len(self.data_list)

    def __iter__(self):
        worker_info = get_worker_info()

        if worker_info is None:  # single-process data loading, return the full iterator
            data_list = self.data_list
        else:
            len_data = len(self.data_list) - len(self.data_list) % worker_info.num_workers
            data_list = self.data_list[:len_data][worker_info.id :: worker_info.num_workers]
            # print(worker_info.num_workers, worker_info.id, len(data_list)/len(self.data_list))

        if self.shuffle:
            random.shuffle(data_list)

        while True:
            for idx in range(len(data_list)):
                # try:
                shard_name = data_list[idx][0]
                data = {}

                img_file = data_list[idx][1]
                img = Image.open(os.path.join(self.img_root, shard_name, img_file + '.jpg')).convert("RGB")

                if is_monochromatic_image(img):
                    continue

                if self.transform is not None:
                    img = self.transform(img)

                data['pixel_values'] = img

                text = data_list[idx][2]
                if self.tokenizer is not None:
                    if isinstance(self.tokenizer, list):
                        assert len(self.tokenizer) == 2
                        data['input_ids'] = self.tokenizer[0](text)[0]
                        data['input_ids_2'] = self.tokenizer[1](text)[0]
                    else:
                        data['input_ids'] = self.tokenizer(text)[0]
                else:
                    data['input_ids'] = text

                yield data

                # except Exception as e:
                #     raise(e)

    def collate_fn(self, examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        if self.tokenizer is not None:
            if isinstance(self.tokenizer, list):
                assert len(self.tokenizer) == 2
                input_ids = torch.stack([example["input_ids"] for example in examples])
                input_ids_2 = torch.stack([example["input_ids_2"] for example in examples])
                return {"pixel_values": pixel_values, "input_ids": input_ids, "input_ids_2": input_ids_2,}
            else:
                input_ids = torch.stack([example["input_ids"] for example in examples])
                return {"pixel_values": pixel_values, "input_ids": input_ids,}
        else:
            input_ids = [example["input_ids"] for example in examples]
            return {"pixel_values": pixel_values, "input_ids": input_ids,}


def make_train_dataset(
        train_data_path,
        size=512,
        tokenizer=None,
        cfg_drop_ratio=0,
        rank=0,
        world_size=1,
        shuffle=True,
    ):

    _image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(size),
            T.CenterCrop((size, size)),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    if tokenizer is not None:
        if isinstance(tokenizer, list):
            assert len(tokenizer) == 2
            tokenizer_1 = TokenizerWrapper(
                tokenizer[0],
                is_train=True,
                proportion_empty_prompts=cfg_drop_ratio,
                use_generic_prompts=False,
            )
            tokenizer_2 = TokenizerWrapper(
                tokenizer[1],
                is_train=True,
                proportion_empty_prompts=cfg_drop_ratio,
                use_generic_prompts=False,
            )
            tokenizer = [tokenizer_1, tokenizer_2]

        else:
            tokenizer = TokenizerWrapper(
                tokenizer,
                is_train=True,
                proportion_empty_prompts=cfg_drop_ratio,
                use_generic_prompts=False,
            )

    train_dataset = TextPromptDataset(
        data_root=train_data_path,
        transform=_image_transform,
        rank=rank,
        world_size=world_size,
        tokenizer=tokenizer,
        shuffle=shuffle,
    )
    return train_dataset


### >>>>>>>> >>>>>>>> Test >>>>>>>> >>>>>>>> ###
if __name__ == "__main__":
    from transformers import CLIPTextModel, CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained(
        "/mnt/bn/ic-research-aigc-editing/fast-diffusion-models/assets/public_models/StableDiffusion/stable-diffusion-v1-5",
        subfolder="tokenizer"
    )
    train_dataset = make_train_dataset(tokenizer=tokenizer, rank=0, world_size=10)

    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, num_workers=0,
        collate_fn=train_dataset.collate_fn if hasattr(train_dataset, 'collate_fn') else None,
    )
    for batch in loader:
        pixel_values = batch["pixel_values"]
        prompt_ids = batch['input_ids']
        from einops import rearrange
        pixel_values = rearrange(pixel_values, 'b c h w -> b h w c')

        for i in range(pixel_values.shape[0]):
            import pdb; pdb.set_trace()
            Image.fromarray(((pixel_values[i] + 1) / 2 * 255).numpy().astype(np.uint8)).save('tmp.png')
            input_id = prompt_ids[i]
            text = tokenizer.decode(input_id).split('<|startoftext|>')[-1].split('<|endoftext|>')[0]
            print(text)
        pass
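
As a usage note, TokenizerWrapper can be exercised on its own: with probability proportion_empty_prompts it blanks a caption (classifier-free guidance dropout) and otherwise pads/truncates it to the tokenizer's model_max_length. The sketch below is illustrative (the tokenizer id is an assumption, and the __main__ block above still needs a real train_data_path before it can run).

# Stand-alone sketch of TokenizerWrapper (the tokenizer id is illustrative).
from transformers import CLIPTokenizer

from src.laion_bytenas import TokenizerWrapper

clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
wrapper = TokenizerWrapper(clip_tok, is_train=True, proportion_empty_prompts=0.1)

input_ids = wrapper(["a photo of a cat", "a watercolor landscape"])
print(input_ids.shape)  # (2, clip_tok.model_max_length), i.e. (2, 77) for CLIP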
    	
        src/pfode_solver.py
    ADDED
    
    | 
         @@ -0,0 +1,120 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
import os, math, random, argparse, logging
from pathlib import Path
from typing import Optional, Union, List, Callable
from collections import OrderedDict
from packaging import version
from tqdm.auto import tqdm
from omegaconf import OmegaConf

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision


class PFODESolver():
    def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None:
        self.t_initial = t_initial
        self.t_terminal = t_terminal
        self.scheduler = scheduler

        train_step_terminal = 0
        train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000
        self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) # 1/1000

    def get_timesteps(self, t_start, t_end, num_steps):
        # (b,) -> (b,1)
        t_start = t_start[:, None]
        t_end = t_end[:, None]
        assert t_start.dim() == 2

        timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device)
        interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps)
        timepoints = t_start + interval * timepoints

        timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # corresponding to the StableDiffusion indexing system, from 999 (t_initial) -> 0
        return timesteps.round().long()

    def solve(self,
              latents,
              unet,
              t_start,
              t_end,
              prompt_embeds,
              negative_prompt_embeds,
              guidance_scale=1.0,
              num_steps = 2,
              num_windows = 1,
    ):
        assert t_start.dim() == 1
        assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end))

        do_classifier_free_guidance = True if guidance_scale > 1 else False
        bsz = latents.shape[0]

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        timestep_cond = None
        if unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim
            ).to(device=latents.device, dtype=latents.dtype)

        timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device)
        timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps)

        # Denoising loop
        with torch.no_grad():
            for i in range(num_steps):
                t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i]
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                ##### STEP: compute the previous noisy sample x_t -> x_t-1
                batch_timesteps = timesteps[:, i].cpu()
                prev_timestep = batch_timesteps - timestep_interval

                alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps]
                alpha_prod_t_prev = torch.zeros_like(alpha_prod_t)
                for ib in range(prev_timestep.shape[0]):
                    alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod
                beta_prod_t = 1 - alpha_prod_t

                alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype)
                alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype)
                beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype)

                if self.scheduler.config.prediction_type == "epsilon":
                    pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:,None,None,None] ** (0.5)
                    pred_epsilon = noise_pred
                elif self.scheduler.config.prediction_type == "v_prediction":
                    pred_original_sample = (alpha_prod_t[:,None,None,None]**0.5) * latents - (beta_prod_t[:,None,None,None]**0.5) * noise_pred
                    pred_epsilon = (alpha_prod_t[:,None,None,None]**0.5) * noise_pred + (beta_prod_t[:,None,None,None]**0.5) * latents
                else:
                    raise NotImplementedError

                pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon
                latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction

        return latents
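For reference, a minimal usage sketch of PFODESolver: it wraps an existing diffusers scheduler and integrates a batch of latents from t_start down to t_end with classifier-free guidance. The checkpoint id, prompt, latent shape, and step counts below are illustrative assumptions, not part of this commit.

    # Minimal sketch (assumptions: SD-v1.5 weights, a CUDA device, illustrative prompt/shapes).
    import torch
    from diffusers import StableDiffusionPipeline, DDIMScheduler
    from src.pfode_solver import PFODESolver

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
    ).to("cuda")
    solver = PFODESolver(scheduler=DDIMScheduler.from_config(pipe.scheduler.config))

    def embed(prompts):
        # Encode prompts with the pipeline's CLIP text encoder.
        tokens = pipe.tokenizer(prompts, padding="max_length", truncation=True,
                                max_length=pipe.tokenizer.model_max_length,
                                return_tensors="pt").input_ids.to("cuda")
        return pipe.text_encoder(tokens)[0]

    bsz = 2
    latents = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
    prompt_embeds = embed(["a photo of a corgi"] * bsz)
    negative_prompt_embeds = embed([""] * bsz)

    # Integrate the probability-flow ODE from t=1 (pure noise) down to t=0.75,
    # i.e. across the first quarter of the trajectory, using 4 solver steps.
    t_start = torch.ones(bsz, device="cuda")
    t_end = torch.full((bsz,), 0.75, device="cuda")
    latents = solver.solve(latents, pipe.unet, t_start, t_end,
                           prompt_embeds, negative_prompt_embeds,
                           guidance_scale=7.5, num_steps=4, num_windows=1)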
    	
        src/scheduler_perflow.py
    ADDED
    
         @@ -0,0 +1,343 @@ 
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


class Time_Windows():
    def __init__(self, t_initial=1, t_terminal=0, num_windows=4, precision=1./1000) -> None:
        assert t_terminal < t_initial
        time_windows = [ 1.*i/num_windows for i in range(1, num_windows+1)][::-1]

        self.window_starts = time_windows                      # [1.0, 0.75, 0.5, 0.25]
        self.window_ends = time_windows[1:] + [t_terminal]     # [0.75, 0.5, 0.25, 0]
        self.precision = precision

    def get_window(self, tp):
        idx = 0
        # robust to numerical error; e.g., (0.6+1/10000) belongs to [0.6, 0.3)
        while (tp-0.1*self.precision) <= self.window_ends[idx]:
            idx += 1
        return self.window_starts[idx], self.window_ends[idx]

    def lookup_window(self, timepoint):
        if timepoint.dim() == 0:
            t_start, t_end = self.get_window(timepoint)
            t_start = torch.ones_like(timepoint) * t_start
            t_end = torch.ones_like(timepoint) * t_end
        else:
            t_start = torch.zeros_like(timepoint)
            t_end = torch.zeros_like(timepoint)
            bsz = timepoint.shape[0]
            for i in range(bsz):
                tp = timepoint[i]
                ts, te = self.get_window(tp)
                t_start[i] = ts
                t_end[i] = te
        return t_start, t_end


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class PeRFlowSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class PeRFlowScheduler(SchedulerMixin, ConfigMixin):
    """
    `PeRFlowScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        prediction_type (`str`, defaults to `epsilon`, *optional*)
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        set_alpha_to_one: bool = False,
        prediction_type: str = "epsilon",
        t_noise: float = 1,
        t_clean: float = 0,
        num_time_windows = 4,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        self.time_windows = Time_Windows(t_initial=t_noise, t_terminal=t_clean,
                                         num_windows=num_time_windows,
                                         precision=1./num_train_timesteps)

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
        """
        if num_inference_steps < self.config.num_time_windows:
            num_inference_steps = self.config.num_time_windows
            print(f"### We recommend a num_inference_steps not less than num_time_windows. It's set as {self.config.num_time_windows}.")

        timesteps = []
        for i in range(self.config.num_time_windows):
            if i < num_inference_steps % self.config.num_time_windows:
                num_steps_cur_win = num_inference_steps // self.config.num_time_windows + 1
            else:
                num_steps_cur_win = num_inference_steps // self.config.num_time_windows

            t_s = self.time_windows.window_starts[i]
            t_e = self.time_windows.window_ends[i]
            timesteps_cur_win = np.linspace(t_s, t_e, num=num_steps_cur_win, endpoint=False)
            timesteps.append(timesteps_cur_win)

        timesteps = np.concatenate(timesteps)

        self.timesteps = torch.from_numpy(
            (timesteps*self.config.num_train_timesteps).astype(np.int64)
        ).to(device)

    def get_window_alpha(self, timestep):
        time_windows = self.time_windows
        num_train_timesteps = self.config.num_train_timesteps

        t_win_start, t_win_end = time_windows.lookup_window(timestep / num_train_timesteps)
        t_win_len = t_win_end - t_win_start
        t_interval = timestep / num_train_timesteps - t_win_start # NOTE: negative value

        idx_start = (t_win_start*num_train_timesteps - 1 ).long()
        idx_end = torch.clamp( (t_win_end*num_train_timesteps - 1 ).long(), min=0)
        alpha_cumprod_s_e = self.alphas_cumprod[idx_start] / self.alphas_cumprod[idx_end]
        gamma_s_e = alpha_cumprod_s_e ** 0.5

        return t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[PeRFlowSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.PeRFlowSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """

        if self.config.prediction_type == "epsilon":
            pred_epsilon = model_output
            t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e = self.get_window_alpha(timestep)
            pred_sample_end = ( sample - (1-t_interval/t_win_len) * ((1-gamma_s_e**2)**0.5) * pred_epsilon ) \
                / ( gamma_s_e + t_interval / t_win_len * (1-gamma_s_e) )
            pred_velocity = (pred_sample_end - sample) / (t_win_end - (t_win_start + t_interval))

        elif self.config.prediction_type == "velocity":
            pred_velocity = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `velocity`."
            )

        # get dt
        idx = torch.argwhere(torch.where(self.timesteps==timestep, 1, 0))
        prev_step = self.timesteps[idx+1] if (idx+1) < len(self.timesteps) else 0
        dt = (prev_step - timestep) / self.config.num_train_timesteps
        dt = dt.to(sample.device, sample.dtype)

        prev_sample = sample + dt * pred_velocity

        if not return_dict:
            return (prev_sample,)
        return PeRFlowSchedulerOutput(prev_sample=prev_sample, pred_original_sample=None)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device) - 1   # indexing from 0

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
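A minimal sketch of how this scheduler is typically dropped into a standard diffusers text-to-image pipeline. The checkpoint path, prompt, and step count below are illustrative assumptions; a PeRFlow-finetuned UNet is assumed to already be loaded into the pipeline.

    # Minimal sketch: swap the pipeline's scheduler for PeRFlowScheduler.
    import torch
    from diffusers import StableDiffusionPipeline
    from src.scheduler_perflow import PeRFlowScheduler

    pipe = StableDiffusionPipeline.from_pretrained(
        "path/to/perflow-sd15",          # hypothetical local PeRFlow checkpoint
        torch_dtype=torch.float16,
    ).to("cuda")

    # PeRFlow divides [0, 1] into num_time_windows windows; inside each window the
    # epsilon prediction is converted into a piecewise-linear velocity, and `step`
    # advances the sample with a plain Euler update.
    pipe.scheduler = PeRFlowScheduler.from_config(
        pipe.scheduler.config, prediction_type="epsilon", num_time_windows=4
    )

    image = pipe("a photo of a corgi", num_inference_steps=8, guidance_scale=4.5).images[0]
    image.save("corgi.png")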
    	
        src/utils_perflow.py
    ADDED
    
         @@ -0,0 +1,77 @@ 
import os
from collections import OrderedDict
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint


def merge_delta_weights_into_unet(pipe, delta_weights):
    unet_weights = pipe.unet.state_dict()
    assert unet_weights.keys() == delta_weights.keys()
    for key in delta_weights.keys():
        dtype = unet_weights[key].dtype
        unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
        unet_weights[key] = unet_weights[key].to(dtype)
    pipe.unet.load_state_dict(unet_weights, strict=True)
    return pipe


def load_delta_weights_into_unet(
    pipe,
    model_path = "hsyan/piecewise-rectified-flow-v0-1",
    base_path = "runwayml/stable-diffusion-v1-5",
):
    ## load delta_weights
    if os.path.exists(os.path.join(model_path, "delta_weights.safetensors")):
        print("### delta_weights exists, loading...")
        delta_weights = OrderedDict()
        with safe_open(os.path.join(model_path, "delta_weights.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                delta_weights[key] = f.get_tensor(key)

    elif os.path.exists(os.path.join(model_path, "diffusion_pytorch_model.safetensors")):
        print("### merged_weights exists, loading...")
        merged_weights = OrderedDict()
        with safe_open(os.path.join(model_path, "diffusion_pytorch_model.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                merged_weights[key] = f.get_tensor(key)

        base_weights = StableDiffusionPipeline.from_pretrained(
            base_path, torch_dtype=torch.float16, safety_checker=None).unet.state_dict()
        assert base_weights.keys() == merged_weights.keys()

        delta_weights = OrderedDict()
        for key in merged_weights.keys():
            delta_weights[key] = merged_weights[key] - base_weights[key].to(device=merged_weights[key].device, dtype=merged_weights[key].dtype)

        print("### saving delta_weights...")
         
     | 
| 50 | 
         
            +
                    save_file(delta_weights, os.path.join(model_path, "delta_weights.safetensors"))
         
     | 
| 51 | 
         
            +
                    
         
     | 
| 52 | 
         
            +
                else:
         
     | 
| 53 | 
         
            +
                    raise ValueError(f"{model_path} does not contain delta weights or merged weights")
         
     | 
| 54 | 
         
            +
                    
         
     | 
| 55 | 
         
            +
                ## merge delta_weights to the target pipeline
         
     | 
| 56 | 
         
            +
                pipe = merge_delta_weights_into_unet(pipe, delta_weights)
         
     | 
| 57 | 
         
            +
                return pipe
         
     | 
| 58 | 
         
            +
                
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
            def load_dreambooth_into_pipeline(pipe, sd_dreambooth):
         
     | 
| 63 | 
         
            +
                assert sd_dreambooth.endswith(".safetensors")
         
     | 
| 64 | 
         
            +
                state_dict = {}
         
     | 
| 65 | 
         
            +
                with safe_open(sd_dreambooth, framework="pt", device="cpu") as f:
         
     | 
| 66 | 
         
            +
                    for key in f.keys():
         
     | 
| 67 | 
         
            +
                        state_dict[key] = f.get_tensor(key)
         
     | 
| 68 | 
         
            +
                
         
     | 
| 69 | 
         
            +
                unet_config = {} # unet, line 449 in convert_ldm_unet_checkpoint
         
     | 
| 70 | 
         
            +
                for key in pipe.unet.config.keys():
         
     | 
| 71 | 
         
            +
                    if key != 'num_class_embeds':
         
     | 
| 72 | 
         
            +
                        unet_config[key] = pipe.unet.config[key]
         
     | 
| 73 | 
         
            +
                        
         
     | 
| 74 | 
         
            +
                pipe.unet.load_state_dict(convert_ldm_unet_checkpoint(state_dict, unet_config), strict=False)
         
     | 
| 75 | 
         
            +
                pipe.vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, pipe.vae.config))
         
     | 
| 76 | 
         
            +
                pipe.text_encoder = convert_ldm_clip_checkpoint(state_dict, text_encoder=pipe.text_encoder)
         
     | 
| 77 | 
         
            +
                return pipe
         
     | 
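A minimal sketch of how these helpers fit together for inference (the checkpoint directory is a hypothetical local path; model ids mirror the defaults above). Note that model_path must be a local directory, since the loader checks it with os.path.exists:

import torch
from diffusers import StableDiffusionPipeline
from src.utils_perflow import load_delta_weights_into_unet

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
)
# merges PeRFlow delta weights (or recovers them from merged weights) into pipe.unet in place
pipe = load_delta_weights_into_unet(
    pipe,
    model_path="./piecewise-rectified-flow-v0-1",  # hypothetical local checkpoint directory
    base_path="runwayml/stable-diffusion-v1-5",
)
pipe = pipe.to("cuda")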
    	
test.yaml
ADDED
@@ -0,0 +1,10 @@
name: test
channels:
  - pytorch
  - nvidia
  - defaults
  - conda-forge
dependencies:
  - python=3.10.12
  - pip=23.2.1
  - cudatoolkit=11.7
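This environment spec can be instantiated with conda env create -f test.yaml (the environment name comes from the name: test field); the Python packages listed in requirements.txt are presumably installed into it separately with pip.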
    	
tsr/__pycache__/system.cpython-310.pyc
ADDED
Binary file (5.19 kB)

tsr/__pycache__/system.cpython-38.pyc
ADDED
Binary file (5.07 kB)

tsr/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13.6 kB)

tsr/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (13.5 kB)

tsr/models/__pycache__/isosurface.cpython-310.pyc
ADDED
Binary file (2.27 kB)

tsr/models/__pycache__/isosurface.cpython-38.pyc
ADDED
Binary file (2.23 kB)

tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
ADDED
Binary file (5.32 kB)

tsr/models/__pycache__/nerf_renderer.cpython-38.pyc
ADDED
Binary file (5.31 kB)

tsr/models/__pycache__/network_utils.cpython-310.pyc
ADDED
Binary file (3.44 kB)

tsr/models/__pycache__/network_utils.cpython-38.pyc
ADDED
Binary file (3.39 kB)
    	
tsr/models/isosurface.py
ADDED
@@ -0,0 +1,52 @@
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torchmcubes import marching_cubes


class IsosurfaceHelper(nn.Module):
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        raise NotImplementedError


class MarchingCubeHelper(IsosurfaceHelper):
    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        self.mc_func: Callable = marching_cubes
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            x, y, z = (
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
            )
            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
            verts = torch.cat(
                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
            ).reshape(-1, 3)
            self._grid_vertices = verts
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        level = -level.view(self.resolution, self.resolution, self.resolution)
        try:
            v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
        except AttributeError:
            print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
            v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
        v_pos = v_pos[..., [2, 1, 0]]
        v_pos = v_pos / (self.resolution - 1.0)
        return v_pos.to(level.device), t_pos_idx.to(level.device)
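A minimal sketch of querying this helper with a synthetic density field (the sphere and resolution are purely illustrative; torchmcubes is assumed to be installed, as in the import above):

import torch
from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)

# toy scalar field on the helper's unit grid: positive inside a sphere of radius 0.3
pts = helper.grid_vertices                      # (64**3, 3), coordinates in [0, 1]
level = 0.3 - (pts - 0.5).norm(dim=-1)

v_pos, t_pos_idx = helper(level)                # mesh vertices in [0, 1] and triangle indices
print(v_pos.shape, t_pos_idx.shape)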
    	
tsr/models/nerf_renderer.py
ADDED
@@ -0,0 +1,180 @@
from dataclasses import dataclass
from typing import Dict

import torch
import torch.nn.functional as F
from einops import rearrange, reduce

from ..utils import (
    BaseModule,
    chunk_batch,
    get_activation,
    rays_intersect_bbox,
    scale_tensor,
)


class TriplaneNeRFRenderer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        radius: float

        feature_reduction: str = "concat"
        density_activation: str = "trunc_exp"
        density_bias: float = -1.0
        color_activation: str = "sigmoid"
        num_samples_per_ray: int = 128
        randomized: bool = False

    cfg: Config

    def configure(self) -> None:
        assert self.cfg.feature_reduction in ["concat", "mean"]
        self.chunk_size = 0

    def set_chunk_size(self, chunk_size: int):
        assert (
            chunk_size >= 0
        ), "chunk_size must be a non-negative integer (0 for no chunking)."
        self.chunk_size = chunk_size

    def query_triplane(
        self,
        decoder: torch.nn.Module,
        positions: torch.Tensor,
        triplane: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        input_shape = positions.shape[:-1]
        positions = positions.view(-1, 3)

        # positions in (-radius, radius)
        # normalized to (-1, 1) for grid sample
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        def _query_chunk(x):
            indices2D: torch.Tensor = torch.stack(
                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
                dim=-3,
            )
            out: torch.Tensor = F.grid_sample(
                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
                align_corners=False,
                mode="bilinear",
            )
            if self.cfg.feature_reduction == "concat":
                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
            elif self.cfg.feature_reduction == "mean":
                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
            else:
                raise NotImplementedError

            net_out: Dict[str, torch.Tensor] = decoder(out)
            return net_out

        if self.chunk_size > 0:
            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
        else:
            net_out = _query_chunk(positions)

        net_out["density_act"] = get_activation(self.cfg.density_activation)(
            net_out["density"] + self.cfg.density_bias
        )
        net_out["color"] = get_activation(self.cfg.color_activation)(
            net_out["features"]
        )

        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}

        return net_out

    def _forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        **kwargs,
    ):
        rays_shape = rays_o.shape[:-1]
        rays_o = rays_o.view(-1, 3)
        rays_d = rays_d.view(-1, 3)
        n_rays = rays_o.shape[0]

        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
        t_near, t_far = t_near[rays_valid], t_far[rays_valid]

        t_vals = torch.linspace(
            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
        )
        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)

        xyz = (
            rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
        )  # (N_rays, N_sample, 3)

        mlp_out = self.query_triplane(
            decoder=decoder,
            positions=xyz,
            triplane=triplane,
        )

        eps = 1e-10
        # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
        deltas = t_vals[1:] - t_vals[:-1]  # (N_rays, N_samples)
        alpha = 1 - torch.exp(
            -deltas * mlp_out["density_act"][..., 0]
        )  # (N_rays, N_samples)
        accum_prod = torch.cat(
            [
                torch.ones_like(alpha[:, :1]),
                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
            ],
            dim=-1,
        )
        weights = alpha * accum_prod  # (N_rays, N_samples)
        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
        opacity_ = weights.sum(dim=-1)  # (N_rays)

        comp_rgb = torch.zeros(
            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
        )
        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
        comp_rgb[rays_valid] = comp_rgb_
        opacity[rays_valid] = opacity_

        comp_rgb += 1 - opacity[..., None]
        comp_rgb = comp_rgb.view(*rays_shape, 3)

        return comp_rgb

    def forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if triplane.ndim == 4:
            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
        else:
            comp_rgb = torch.stack(
                [
                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
                    for i in range(triplane.shape[0])
                ],
                dim=0,
            )

        return comp_rgb

    def train(self, mode=True):
        self.randomized = mode and self.cfg.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()
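For orientation, a minimal sketch of driving the renderer with a stand-in decoder (the dict-based config construction, the channel counts, and the toy camera are illustrative assumptions; in the real system the decoder is the NeRFMLP defined below and the triplane comes from the backbone):

import torch
import torch.nn as nn
from tsr.models.nerf_renderer import TriplaneNeRFRenderer

class DummyDecoder(nn.Module):
    """Stand-in for NeRFMLP: maps sampled triplane features to density (1) + features (3)."""
    def __init__(self, in_channels: int):
        super().__init__()
        self.proj = nn.Linear(in_channels, 4)

    def forward(self, x):
        out = self.proj(x)
        return {"density": out[..., 0:1], "features": out[..., 1:4]}

renderer = TriplaneNeRFRenderer({"radius": 0.87})     # cfg assumed to be parsed by BaseModule
decoder = DummyDecoder(in_channels=3 * 40)            # "concat" reduction: 3 planes x 40 channels
triplane = torch.randn(3, 40, 64, 64)                 # (Np, Cp, Hp, Wp)

rays_o = torch.zeros(8, 8, 3)
rays_o[..., 2] = 2.0                                  # toy cameras on the +z axis
rays_d = torch.zeros(8, 8, 3)
rays_d[..., 2] = -1.0                                 # all rays look toward the origin

rgb = renderer(decoder, triplane, rays_o, rays_d)     # (8, 8, 3), composited over a white background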
    	
tsr/models/network_utils.py
ADDED
@@ -0,0 +1,124 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from ..utils import BaseModule


class TriplaneUpsampleNetwork(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int
        out_channels: int

    cfg: Config

    def configure(self) -> None:
        self.upsample = nn.ConvTranspose2d(
            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
        )

    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
        triplanes_up = rearrange(
            self.upsample(
                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
            ),
            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
            Np=3,
        )
        return triplanes_up


class NeRFMLP(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int
        n_neurons: int
        n_hidden_layers: int
        activation: str = "relu"
        bias: bool = True
        weight_init: Optional[str] = "kaiming_uniform"
        bias_init: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        layers = [
            self.make_linear(
                self.cfg.in_channels,
                self.cfg.n_neurons,
                bias=self.cfg.bias,
                weight_init=self.cfg.weight_init,
                bias_init=self.cfg.bias_init,
            ),
            self.make_activation(self.cfg.activation),
        ]
        for i in range(self.cfg.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.cfg.n_neurons,
                    self.cfg.n_neurons,
                    bias=self.cfg.bias,
                    weight_init=self.cfg.weight_init,
                    bias_init=self.cfg.bias_init,
                ),
                self.make_activation(self.cfg.activation),
            ]
        layers += [
            self.make_linear(
                self.cfg.n_neurons,
                4,  # density 1 + features 3
                bias=self.cfg.bias,
                weight_init=self.cfg.weight_init,
                bias_init=self.cfg.bias_init,
            )
        ]
        self.layers = nn.Sequential(*layers)

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError

        return layer

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(self, x):
        inp_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])

        features = self.layers(x)
        features = features.reshape(*inp_shape, -1)
        out = {"density": features[..., 0:1], "features": features[..., 1:4]}

        return out
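A quick shape check for these two modules (channel counts, plane size, and dict-based config construction are illustrative assumptions):

import torch
from tsr.models.network_utils import TriplaneUpsampleNetwork, NeRFMLP

# the kernel_size=2, stride=2 transposed conv doubles the triplane resolution
upsampler = TriplaneUpsampleNetwork({"in_channels": 40, "out_channels": 40})
triplanes = torch.randn(2, 3, 40, 32, 32)                 # (B, Np, Ci, Hp, Wp)
print(upsampler(triplanes).shape)                         # torch.Size([2, 3, 40, 64, 64])

# the final linear layer has 4 outputs: 1 density + 3 color features
decoder = NeRFMLP({"in_channels": 3 * 40, "n_neurons": 64, "n_hidden_layers": 3})
out = decoder(torch.randn(4096, 3 * 40))
print(out["density"].shape, out["features"].shape)        # (4096, 1) and (4096, 3)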
    	
tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
ADDED
Binary file (2.42 kB)

tsr/models/tokenizers/__pycache__/image.cpython-38.pyc
ADDED
Binary file (2.39 kB)

tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
ADDED
Binary file (1.79 kB)

tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc
ADDED
Binary file (1.77 kB)
    	
tsr/models/tokenizers/image.py
ADDED
@@ -0,0 +1,66 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel

from ...utils import BaseModule


class DINOSingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(
                    repo_id=self.cfg.pretrained_model_name_or_path,
                    filename="config.json",
                )
            )
        )

        if self.cfg.enable_gradient_checkpointing:
            self.model.encoder.gradient_checkpointing = True

        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
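A minimal sketch of calling this tokenizer (the input size and dict-based config construction are illustrative assumptions; the default config downloads facebook/dino-vitb16's config.json from the Hub, so network access is assumed, and the ViT weights are randomly initialized here until the full system checkpoint is loaded):

import torch
from tsr.models.tokenizers.image import DINOSingleImageTokenizer

tokenizer = DINOSingleImageTokenizer({})      # default Config
images = torch.rand(2, 3, 224, 224)           # RGB in [0, 1]; ImageNet normalization happens inside forward
tokens = tokenizer(images)                    # (B, Ct, Nt) = (batch, ViT hidden size, 1 + num patches)
print(tokens.shape)                           # torch.Size([2, 768, 197]) for 224x224 inputs with ViT-B/16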
    	
        tsr/models/tokenizers/triplane.py
    ADDED
    
    | 
         @@ -0,0 +1,45 @@ 
     | 
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from ...utils import BaseModule
+
+
+class Triplane1DTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        plane_size: int
+        num_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.embeddings = nn.Parameter(
+            torch.randn(
+                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
+                dtype=torch.float32,
+            )
+            * 1
+            / math.sqrt(self.cfg.num_channels)
+        )
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return rearrange(
+            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
+            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
+        )
+
+    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
+        batch_size, Ct, Nt = tokens.shape
+        assert Nt == self.cfg.plane_size**2 * 3
+        assert Ct == self.cfg.num_channels
+        return rearrange(
+            tokens,
+            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
+            Np=3,
+            Hp=self.cfg.plane_size,
+            Wp=self.cfg.plane_size,
+        )
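Triplane1DTokenizer above holds one set of learned plane embeddings and broadcasts them over the batch: `forward(batch_size)` returns tokens of shape (B, Ct, 3*Hp*Wp), and `detokenize` folds the transformer output back into three (Ct, Hp, Wp) planes. A minimal standalone sketch of that shape round trip follows; the batch size and the plane_size=32 / num_channels=1024 values are illustrative assumptions, not taken from any config in this commit.

import math
import torch
from einops import rearrange, repeat

B, Np, Ct, Hp = 2, 3, 1024, 32                              # assumed example sizes
embeddings = torch.randn(Np, Ct, Hp, Hp) / math.sqrt(Ct)    # same init scale as configure()

# forward(): broadcast the shared planes over the batch and flatten to a 1D token sequence
tokens = rearrange(
    repeat(embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=B),
    "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
)
print(tokens.shape)                                          # torch.Size([2, 1024, 3072])

# detokenize(): fold the token axis back into three spatial planes
planes = rearrange(tokens, "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", Np=Np, Hp=Hp, Wp=Hp)
print(planes.shape)                                          # torch.Size([2, 3, 1024, 32, 32])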
    	
tsr/models/transformer/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (15.3 kB)

tsr/models/transformer/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (15.2 kB)

tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc
ADDED
Binary file (9.65 kB)

tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc
ADDED
Binary file (9.49 kB)

tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc
ADDED
Binary file (4.91 kB)

tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc
ADDED
Binary file (4.85 kB)

tsr/models/transformer/attention.py
ADDED
@@ -0,0 +1,653 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`,  *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`,  *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(
+                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels,
+                num_groups=cross_attention_norm_num_groups,
+                eps=1e-5,
+                affine=True,
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        linear_cls = nn.Linear
+
+        self.linear_cls = linear_cls
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        # set attention processor
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0()
+                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(
+            batch_size // head_size, seq_len, dim * head_size
+        )
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
+        the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+    ) -> torch.Tensor:
+        r"""
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0],
+                query.shape[1],
+                key.shape[1],
+                dtype=query.dtype,
+                device=query.device,
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self,
+        attention_mask: torch.Tensor,
+        target_length: int,
+        batch_size: int,
+        out_dim: int = 3,
+    ) -> torch.Tensor:
+        r"""
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                The attention mask to prepare.
+            target_length (`int`):
+                The target length of the attention mask. This is the length of the attention mask after padding.
+            batch_size (`int`):
+                The batch size, which is used to repeat the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`):
+                The output dimension of the attention mask. Can be either `3` or `4`.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (
+                    attention_mask.shape[0],
+                    attention_mask.shape[1],
+                    target_length,
+                )
+                padding = torch.zeros(
+                    padding_shape,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #       we want to instead pad by (0, remaining_length), where remaining_length is:
+                #       remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+        `Attention` class.
+
+        Args:
         
            +
                        encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
         
     | 
| 425 | 
         
            +
             
     | 
| 426 | 
         
            +
                    Returns:
         
     | 
| 427 | 
         
            +
                        `torch.Tensor`: The normalized encoder hidden states.
         
     | 
| 428 | 
         
            +
                    """
         
     | 
| 429 | 
         
            +
                    assert (
         
     | 
| 430 | 
         
            +
                        self.norm_cross is not None
         
     | 
| 431 | 
         
            +
                    ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
         
     | 
| 432 | 
         
            +
             
     | 
| 433 | 
         
            +
                    if isinstance(self.norm_cross, nn.LayerNorm):
         
     | 
| 434 | 
         
            +
                        encoder_hidden_states = self.norm_cross(encoder_hidden_states)
         
     | 
| 435 | 
         
            +
                    elif isinstance(self.norm_cross, nn.GroupNorm):
         
     | 
| 436 | 
         
            +
                        # Group norm norms along the channels dimension and expects
         
     | 
| 437 | 
         
            +
                        # input to be in the shape of (N, C, *). In this case, we want
         
     | 
| 438 | 
         
            +
                        # to norm along the hidden dimension, so we need to move
         
     | 
| 439 | 
         
            +
                        # (batch_size, sequence_length, hidden_size) ->
         
     | 
| 440 | 
         
            +
                        # (batch_size, hidden_size, sequence_length)
         
     | 
| 441 | 
         
            +
                        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
         
     | 
| 442 | 
         
            +
                        encoder_hidden_states = self.norm_cross(encoder_hidden_states)
         
     | 
| 443 | 
         
            +
                        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
         
     | 
| 444 | 
         
            +
                    else:
         
     | 
| 445 | 
         
            +
                        assert False
         
     | 
| 446 | 
         
            +
             
     | 
| 447 | 
         
            +
                    return encoder_hidden_states
         
     | 
| 448 | 
         
            +
             
     | 
| 449 | 
         
            +
                @torch.no_grad()
         
     | 
| 450 | 
         
            +
                def fuse_projections(self, fuse=True):
         
     | 
| 451 | 
         
            +
                    is_cross_attention = self.cross_attention_dim != self.query_dim
         
     | 
| 452 | 
         
            +
                    device = self.to_q.weight.data.device
         
     | 
| 453 | 
         
            +
                    dtype = self.to_q.weight.data.dtype
         
     | 
| 454 | 
         
            +
             
     | 
| 455 | 
         
            +
                    if not is_cross_attention:
         
     | 
| 456 | 
         
            +
                        # fetch weight matrices.
         
     | 
| 457 | 
         
            +
                        concatenated_weights = torch.cat(
         
     | 
| 458 | 
         
            +
                            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
         
     | 
| 459 | 
         
            +
                        )
         
     | 
| 460 | 
         
            +
                        in_features = concatenated_weights.shape[1]
         
     | 
| 461 | 
         
            +
                        out_features = concatenated_weights.shape[0]
         
     | 
| 462 | 
         
            +
             
     | 
| 463 | 
         
            +
                        # create a new single projection layer and copy over the weights.
         
     | 
| 464 | 
         
            +
                        self.to_qkv = self.linear_cls(
         
     | 
| 465 | 
         
            +
                            in_features, out_features, bias=False, device=device, dtype=dtype
         
     | 
| 466 | 
         
            +
                        )
         
     | 
| 467 | 
         
            +
                        self.to_qkv.weight.copy_(concatenated_weights)
         
     | 
| 468 | 
         
            +
             
     | 
| 469 | 
         
            +
                    else:
         
     | 
| 470 | 
         
            +
                        concatenated_weights = torch.cat(
         
     | 
| 471 | 
         
            +
                            [self.to_k.weight.data, self.to_v.weight.data]
         
     | 
| 472 | 
         
            +
                        )
         
     | 
| 473 | 
         
            +
                        in_features = concatenated_weights.shape[1]
         
     | 
| 474 | 
         
            +
                        out_features = concatenated_weights.shape[0]
         
     | 
| 475 | 
         
            +
             
     | 
| 476 | 
         
            +
                        self.to_kv = self.linear_cls(
         
     | 
| 477 | 
         
            +
                            in_features, out_features, bias=False, device=device, dtype=dtype
         
     | 
| 478 | 
         
            +
                        )
         
     | 
| 479 | 
         
            +
                        self.to_kv.weight.copy_(concatenated_weights)
         
     | 
| 480 | 
         
            +
             
     | 
| 481 | 
         
            +
                    self.fused_projections = fuse
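

# A hypothetical helper, for illustration only: for a self-attention module whose
# query/key/value projections carry no bias, `fuse_projections()` above makes the
# single fused `to_qkv` projection equivalent to concatenating the three separate
# projection outputs along the last dimension.
def _fused_qkv_matches_separate(attn: "Attention", x: torch.Tensor) -> bool:
    q, k, v = attn.to_q(x), attn.to_k(x), attn.to_v(x)
    fused = attn.to_qkv(x)  # assumes attn.fuse_projections() has already been called
    return torch.allclose(fused, torch.cat([q, k, v], dim=-1), atol=1e-5)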
         


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
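

# A hypothetical helper, for illustration only: every processor follows the same
# call convention, taking the `Attention` module plus hidden states (and optional
# encoder hidden states / attention mask) and returning the attended hidden states.
def _run_default_processor(attn: "Attention", hidden_states: torch.Tensor) -> torch.Tensor:
    processor = AttnProcessor()
    return processor(attn, hidden_states)  # pure self-attention when no encoder states are given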
         


class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
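
`AttnProcessor` and `AttnProcessor2_0` are interchangeable back-ends for the `Attention` module above: the former computes softmax attention explicitly with `torch.bmm`, while the latter routes the same query/key/value projections through `F.scaled_dot_product_attention` and therefore requires PyTorch 2.0. A minimal usage sketch, assuming the package is importable from the repository root and that `Attention` accepts the same constructor keywords (`query_dim`, `heads`, `dim_head`) that `BasicTransformerBlock` below passes to it:

import torch
from tsr.models.transformer.attention import Attention, AttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16)  # hypothetical sizes for illustration
processor = AttnProcessor2_0()                        # raises ImportError on PyTorch < 2.0
hidden_states = torch.randn(2, 128, 64)               # (batch, sequence, channels)
out = processor(attn, hidden_states)                  # self-attention; output keeps shape (2, 128, 64)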
         
tsr/models/transformer/basic_transformer_block.py
ADDED
@@ -0,0 +1,334 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------
#
# Modified 2024 by the Tripo AI and Stability AI Team.
#
# Copyright (c) 2024 Tripo AI & Stability AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention import Attention


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
        )

        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states

        return hidden_states
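

# A hypothetical helper, for illustration only: the block applies pre-norm
# self-attention, optional pre-norm cross-attention (only when a
# `cross_attention_dim` is configured), and a pre-norm feed-forward, each with a
# residual connection, so the output keeps the input shape (batch, sequence, dim).
def _example_basic_transformer_block() -> torch.Tensor:
    block = BasicTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
    hidden_states = torch.randn(2, 128, 64)  # (batch, sequence, dim)
    return block(hidden_states)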
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
            class FeedForward(nn.Module):
         
     | 
| 210 | 
         
            +
                r"""
         
     | 
| 211 | 
         
            +
                A feed-forward layer.
         
     | 
| 212 | 
         
    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
            dtype=gate.dtype
        )

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        linear_cls = nn.Linear

        self.proj = linear_cls(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        args = ()
        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)
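For orientation, a minimal usage sketch of the feed-forward block defined above; the batch, token, and channel sizes are illustrative assumptions, and the import path follows this commit's layout (tsr/models/transformer/basic_transformer_block.py):

import torch
from tsr.models.transformer.basic_transformer_block import FeedForward

# With the default "geglu" activation, GEGLU projects dim -> 2 * inner_dim (inner_dim = dim * mult),
# gates one half with a GELU of the other, and the final linear maps inner_dim back to dim.
ff = FeedForward(dim=768, mult=4, activation_fn="geglu")
x = torch.randn(2, 16, 768)   # (batch, tokens, dim), illustrative sizes
y = ff(x)
assert y.shape == x.shape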
    	
tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,219 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------
#
# Modified 2024 by the Tripo AI and Stability AI Team.
#
# Copyright (c) 2024 Tripo AI & Stability AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ...utils import BaseModule
from .basic_transformer_block import BasicTransformerBlock


class Transformer1D(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        activation_fn: str = "geglu"
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        linear_cls = nn.Linear

        # 2. Define input layers
        self.in_channels = self.cfg.in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )

        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Transformer1DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.

        Returns:
            torch.FloatTensor
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)
            ) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch, _, seq_len = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 1).reshape(
            batch, seq_len, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, seq_len, inner_dim)
            .permute(0, 2, 1)
            .contiguous()
        )

        output = hidden_states + residual

        return output
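For orientation, a minimal sketch of driving this 1D transformer on its own. It assumes that BaseModule subclasses accept a plain config dict, as tsr/system.py below does when it instantiates the backbone via find_class; all channel, head, and sequence sizes are illustrative.

import torch
from tsr.models.transformer.transformer_1d import Transformer1D

backbone = Transformer1D(
    {
        "in_channels": 1024,        # must be divisible by norm_num_groups (32 by default)
        "num_attention_heads": 16,
        "attention_head_dim": 64,   # blocks run at 16 * 64 = 1024 channels internally
        "num_layers": 2,
        "cross_attention_dim": 768,
    }
)
tokens = torch.randn(1, 1024, 3072)   # (batch, channels, sequence), illustrative sizes
cond = torch.randn(1, 1445, 768)      # (batch, cond tokens, cross_attention_dim)
out = backbone(tokens, encoder_hidden_states=cond)
assert out.shape == tokens.shape      # proj_out plus the residual keep the input layout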
    	
tsr/system.py ADDED
@@ -0,0 +1,203 @@
| 1 | 
         
            +
            import math
         
     | 
| 2 | 
         
            +
            import os
         
     | 
| 3 | 
         
            +
            from dataclasses import dataclass, field
         
     | 
| 4 | 
         
            +
            from typing import List, Union
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import numpy as np
         
     | 
| 7 | 
         
            +
            import PIL.Image
         
     | 
| 8 | 
         
            +
            import torch
         
     | 
| 9 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 10 | 
         
            +
            import trimesh
         
     | 
| 11 | 
         
            +
            from einops import rearrange
         
     | 
| 12 | 
         
            +
            from huggingface_hub import hf_hub_download
         
     | 
| 13 | 
         
            +
            from omegaconf import OmegaConf
         
     | 
| 14 | 
         
            +
            from PIL import Image
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            from .models.isosurface import MarchingCubeHelper
         
     | 
| 17 | 
         
            +
            from .utils import (
         
     | 
| 18 | 
         
            +
                BaseModule,
         
     | 
| 19 | 
         
            +
                ImagePreprocessor,
         
     | 
| 20 | 
         
            +
                find_class,
         
     | 
| 21 | 
         
            +
                get_spherical_cameras,
         
     | 
| 22 | 
         
            +
                scale_tensor,
         
     | 
| 23 | 
         
            +
            )
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            class TSR(BaseModule):
         
     | 
| 27 | 
         
            +
                @dataclass
         
     | 
| 28 | 
         
            +
                class Config(BaseModule.Config):
         
     | 
| 29 | 
         
            +
                    cond_image_size: int
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                    image_tokenizer_cls: str
         
     | 
| 32 | 
         
            +
                    image_tokenizer: dict
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    tokenizer_cls: str
         
     | 
| 35 | 
         
            +
                    tokenizer: dict
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                    backbone_cls: str
         
     | 
| 38 | 
         
            +
                    backbone: dict
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                    post_processor_cls: str
         
     | 
| 41 | 
         
            +
                    post_processor: dict
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                    decoder_cls: str
         
     | 
| 44 | 
         
            +
                    decoder: dict
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                    renderer_cls: str
         
     | 
| 47 | 
         
            +
                    renderer: dict
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                cfg: Config
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                @classmethod
         
     | 
| 52 | 
         
            +
                def from_pretrained(
         
     | 
| 53 | 
         
            +
                    cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
         
     | 
| 54 | 
         
            +
                ):
         
     | 
| 55 | 
         
            +
                    if os.path.isdir(pretrained_model_name_or_path):
         
     | 
| 56 | 
         
            +
                        config_path = os.path.join(pretrained_model_name_or_path, config_name)
         
     | 
| 57 | 
         
            +
                        weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
         
     | 
| 58 | 
         
            +
                    else:
         
     | 
| 59 | 
         
            +
                        config_path = hf_hub_download(
         
     | 
| 60 | 
         
            +
                            repo_id=pretrained_model_name_or_path, filename=config_name
         
     | 
| 61 | 
         
            +
                        )
         
     | 
| 62 | 
         
            +
                        weight_path = hf_hub_download(
         
     | 
| 63 | 
         
            +
                            repo_id=pretrained_model_name_or_path, filename=weight_name
         
     | 
| 64 | 
         
            +
                        )
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    cfg = OmegaConf.load(config_path)
         
     | 
| 67 | 
         
            +
                    OmegaConf.resolve(cfg)
         
     | 
| 68 | 
         
            +
                    model = cls(cfg)
         
     | 
| 69 | 
         
            +
                    ckpt = torch.load(weight_path, map_location="cpu")
         
     | 
| 70 | 
         
            +
                    model.load_state_dict(ckpt)
         
     | 
| 71 | 
         
            +
                    return model
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                def configure(self):
         
     | 
| 74 | 
         
            +
                    self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
         
     | 
| 75 | 
         
            +
                        self.cfg.image_tokenizer
         
     | 
| 76 | 
         
            +
                    )
         
     | 
| 77 | 
         
            +
                    self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
         
     | 
| 78 | 
         
            +
                    self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
         
     | 
| 79 | 
         
            +
                    self.post_processor = find_class(self.cfg.post_processor_cls)(
         
     | 
| 80 | 
         
            +
                        self.cfg.post_processor
         
     | 
| 81 | 
         
            +
                    )
         
     | 
| 82 | 
         
            +
                    self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
         
     | 
| 83 | 
         
            +
                    self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
         
     | 
| 84 | 
         
            +
                    self.image_processor = ImagePreprocessor()
         
     | 
| 85 | 
         
            +
                    self.isosurface_helper = None
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                def forward(
         
     | 
| 88 | 
         
            +
                    self,
         
     | 
| 89 | 
         
            +
                    image: Union[
         
     | 
| 90 | 
         
            +
                        PIL.Image.Image,
         
     | 
| 91 | 
         
            +
                        np.ndarray,
         
     | 
| 92 | 
         
            +
                        torch.FloatTensor,
         
     | 
| 93 | 
         
            +
                        List[PIL.Image.Image],
         
     | 
| 94 | 
         
            +
                        List[np.ndarray],
         
     | 
| 95 | 
         
            +
                        List[torch.FloatTensor],
         
     | 
| 96 | 
         
            +
                    ],
         
     | 
| 97 | 
         
            +
                    device: str,
         
     | 
| 98 | 
         
            +
                ) -> torch.FloatTensor:
         
     | 
| 99 | 
         
            +
                    rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
         
     | 
| 100 | 
         
            +
                        device
         
     | 
| 101 | 
         
            +
                    )
         
     | 
| 102 | 
         
            +
                    batch_size = rgb_cond.shape[0]
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                    input_image_tokens: torch.Tensor = self.image_tokenizer(
         
     | 
| 105 | 
         
            +
                        rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
         
     | 
| 106 | 
         
            +
                    )
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                    input_image_tokens = rearrange(
         
     | 
| 109 | 
         
            +
                        input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
         
     | 
| 110 | 
         
            +
                    )
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                    tokens: torch.Tensor = self.tokenizer(batch_size)
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                    tokens = self.backbone(
         
     | 
| 115 | 
         
            +
                        tokens,
         
     | 
| 116 | 
         
            +
                        encoder_hidden_states=input_image_tokens,
         
     | 
| 117 | 
         
            +
                    )
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
         
     | 
| 120 | 
         
            +
                    return scene_codes
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                def render(
         
     | 
| 123 | 
         
            +
                    self,
         
     | 
| 124 | 
         
            +
                    scene_codes,
         
     | 
| 125 | 
         
            +
                    n_views: int,
         
     | 
| 126 | 
         
            +
                    elevation_deg: float = 0.0,
         
     | 
| 127 | 
         
            +
                    camera_distance: float = 1.9,
         
     | 
| 128 | 
         
            +
                    fovy_deg: float = 40.0,
         
     | 
| 129 | 
         
            +
                    height: int = 256,
         
     | 
| 130 | 
         
            +
                    width: int = 256,
         
     | 
| 131 | 
         
            +
                    return_type: str = "pil",
         
     | 
| 132 | 
         
            +
                ):
         
     | 
| 133 | 
         
            +
                    rays_o, rays_d = get_spherical_cameras(
         
     | 
| 134 | 
         
            +
                        n_views, elevation_deg, camera_distance, fovy_deg, height, width
         
     | 
| 135 | 
         
            +
                    )
         
     | 
| 136 | 
         
            +
                    rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                    def process_output(image: torch.FloatTensor):
         
     | 
| 139 | 
         
            +
                        if return_type == "pt":
         
     | 
| 140 | 
         
            +
                            return image
         
     | 
| 141 | 
         
            +
                        elif return_type == "np":
         
     | 
| 142 | 
         
            +
                            return image.detach().cpu().numpy()
         
     | 
| 143 | 
         
            +
                        elif return_type == "pil":
         
     | 
| 144 | 
         
            +
                            return Image.fromarray(
         
     | 
| 145 | 
         
            +
                                (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
         
     | 
| 146 | 
         
            +
                            )
         
     | 
| 147 | 
         
            +
                        else:
         
     | 
| 148 | 
         
            +
                            raise NotImplementedError
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                    images = []
         
     | 
| 151 | 
         
            +
                    for scene_code in scene_codes:
         
     | 
| 152 | 
         
            +
                        images_ = []
         
     | 
| 153 | 
         
            +
                        for i in range(n_views):
         
     | 
| 154 | 
         
            +
                            with torch.no_grad():
         
     | 
| 155 | 
         
            +
                                image = self.renderer(
         
     | 
| 156 | 
         
            +
                                    self.decoder, scene_code, rays_o[i], rays_d[i]
         
     | 
| 157 | 
         
            +
                                )
         
     | 
| 158 | 
         
            +
                            images_.append(process_output(image))
         
     | 
| 159 | 
         
            +
                        images.append(images_)
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                    return images
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
    def set_marching_cubes_resolution(self, resolution: int):
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
                        self.isosurface_helper.points_range,
                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                    ),
                    scene_code,
                )["density_act"]
            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
            v_pos = scale_tensor(
                v_pos,
                self.isosurface_helper.points_range,
                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
            )
            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]
            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy(),
            )
            meshes.append(mesh)
        return meshes
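Usage sketch (illustrative, not part of the diff): extract_mesh consumes the scene codes produced by the system's forward pass and returns one trimesh.Trimesh per input image. The call pattern below mirrors the run.py added in this commit; the checkpoint name, device choice, and output path are assumptions.

# Illustrative only; model name and paths are assumptions, not part of this commit.
import torch
from tsr.system import TSR

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = TSR.from_pretrained(
    "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
)
model.to(device)

scene_codes = model([image], device=device)   # image: a preprocessed PIL image
meshes = model.extract_mesh(scene_codes, resolution=256, threshold=25.0)
meshes[0].export("output/0/mesh.obj")         # matches the sample output folder in this commit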
    	
tsr/utils.py ADDED
@@ -0,0 +1,474 @@
import importlib
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio
import numpy as np
import PIL.Image
import rembg
import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from omegaconf import DictConfig, OmegaConf
from PIL import Image


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg


def find_class(cls_string):
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls

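find_class resolves a dotted import path into a class object; this is how the *_cls strings in the model config appear to be turned into tokenizer, backbone, and decoder classes. A minimal sketch (the path below is just a stand-in, any importable "package.module.ClassName" works):

# Illustrative: resolve a class by its dotted path (the path is an assumption).
ImageCls = find_class("PIL.Image.Image")
print(ImageCls)  # <class 'PIL.Image.Image'>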
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)

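The fov argument is in radians, the focal length is derived from the image height, and the principal point is placed at the image center. A quick sketch for a 512x512 view with a 40 degree vertical field of view (the numbers are illustrative):

import math

K = get_intrinsic_from_fov(math.radians(40.0), H=512, W=512)
# K is a 3x3 torch tensor:
# [[f, 0, 256],
#  [0, f, 256],
#  [0, 0,   1]]   with f = 0.5 * 512 / tan(20 deg) ~ 703.4
K_batched = get_intrinsic_from_fov(math.radians(40.0), H=512, W=512, bs=4)  # shape (4, 3, 3)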
class BaseModule(nn.Module):
    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        raise NotImplementedError

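Subclasses declare a nested dataclass Config and build their layers in configure(); __init__ merges the incoming dict or DictConfig into that structured config before calling it. A minimal sketch of the pattern (TinyDecoder and its fields are hypothetical, not one of the commit's modules):

class TinyDecoder(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_dim: int = 32
        hidden_dim: int = 64

    cfg: Config

    def configure(self) -> None:
        # self.cfg is already a structured config when configure() runs
        self.net = nn.Sequential(
            nn.Linear(self.cfg.in_dim, self.cfg.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.cfg.hidden_dim, 4),
        )

    def forward(self, x):
        return self.net(x)


decoder = TinyDecoder({"hidden_dim": 128})  # overrides only hidden_dim; in_dim keeps its default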
class ImagePreprocessor:
    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image

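The preprocessor accepts single images or lists and always returns channels-last float tensors in [0, 1], resized to the model's input resolution. A sketch (the input path and target size are illustrative):

preprocessor = ImagePreprocessor()
img = PIL.Image.open("examples/chair.png").convert("RGB")  # hypothetical input path
batch = preprocessor([img], size=512)   # torch tensor of shape (1, 512, 512, 3)
single = preprocessor(img, size=512)    # a single PIL input is also returned as a (1, 512, 512, 3) batch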
def rays_intersect_bbox(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    radius: float,
    near: float = 0.0,
    valid_thresh: float = 0.01,
):
    input_shape = rays_o.shape[:-1]
    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
    rays_d_valid = torch.where(
        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
    )
    if type(radius) in [int, float]:
        radius = torch.FloatTensor(
            [[-radius, radius], [-radius, radius], [-radius, radius]]
        ).to(rays_o.device)
    radius = (
        1.0 - 1.0e-3
    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
    t_far = torch.maximum(interx0, interx1).amin(dim=-1)

    # check whether a ray intersects the bbox or not
    rays_valid = t_far - t_near > valid_thresh

    t_near[torch.where(~rays_valid)] = 0.0
    t_far[torch.where(~rays_valid)] = 0.0

    t_near = t_near.view(*input_shape, 1)
    t_far = t_far.view(*input_shape, 1)
    rays_valid = rays_valid.view(*input_shape)

    return t_near, t_far, rays_valid

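A quick sanity-check sketch: a ray starting outside the unit box and pointing at the origin should intersect it, while a ray pointing away should not (the values below are illustrative):

rays_o = torch.tensor([[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
rays_d = torch.tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # toward / away from the box
t_near, t_far, valid = rays_intersect_bbox(rays_o, rays_d, radius=1.0)
# valid -> tensor([ True, False]); for the first ray t_near ~ 1.0 and t_far ~ 3.0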
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    if chunk_size <= 0:
        return func(*args, **kwargs)
    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    # max(1, B) to support B == 0
    for i in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[
                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
            chunk_length = len(out_chunk)
            out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
        elif isinstance(out_chunk, dict):
            pass
        else:
            print(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
            exit(1)
        for k, v in out_chunk.items():
            v = v if torch.is_grad_enabled() else v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all([vv is None for vv in v]):
            # allow None in return value
            out_merged[k] = None
        elif all([isinstance(vv, torch.Tensor) for vv in v]):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_type is torch.Tensor:
        return out_merged[0]
    elif out_type in [tuple, list]:
        return out_type([out_merged[i] for i in range(chunk_length)])
    elif out_type is dict:
        return out_merged

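chunk_batch splits the leading (batch) dimension of every tensor argument, calls func on each slice, and concatenates the results, which lets the renderer query large numbers of ray samples without exhausting GPU memory. A small sketch (fake_decoder is a stand-in, not the real decoder):

def fake_decoder(points: torch.Tensor) -> torch.Tensor:
    # stand-in for an expensive per-point network query
    return points.norm(dim=-1, keepdim=True)

points = torch.randn(10_000, 3)
densities = chunk_batch(fake_decoder, 2048, points)  # processed in slices of 2048 points
assert densities.shape == (10_000, 1)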
ValidScale = Union[Tuple[float, float], torch.FloatTensor]


def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, torch.FloatTensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat

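scale_tensor linearly remaps values from one range to another; extract_mesh above uses it to move marching-cubes grid vertices from the helper's points_range into the renderer's [-radius, radius] cube. A one-line sketch (the ranges are illustrative):

grid = torch.tensor([[0.0, 0.5, 1.0]])
scaled = scale_tensor(grid, (0.0, 1.0), (-0.87, 0.87))
# scaled ~ tensor([[-0.8700, 0.0000, 0.8700]])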
def get_activation(name) -> Callable:
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "softplus":
        return lambda x: F.softplus(x)
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")


def get_ray_directions(
    H: int,
    W: int,
    focal: Union[float, Tuple[float, float]],
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
    normalize: bool = True,
) -> torch.FloatTensor:
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether to use pixel centers
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    pixel_center = 0.5 if use_pixel_centers else 0

    if isinstance(focal, float):
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )

    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)

    if normalize:
        directions = F.normalize(directions, dim=-1)

    return directions


def get_rays(
    directions,
    c2w,
    keepdim=False,
    normalize=False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d

def get_spherical_cameras(
    n_views: int,
    elevation_deg: float,
    camera_distance: float,
    fovy_deg: float,
    height: int,
    width: int,
):
    azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
    camera_distances = torch.full_like(elevation_deg, camera_distance)

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180

    # convert spherical coordinates to cartesian coordinates
    # right hand coordinate system, x back, y right, z up
    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )

    # default scene center at origin
    center = torch.zeros_like(camera_positions)
    # default camera up direction as +z
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)

    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180

    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up), dim=-1)
    up = F.normalize(torch.cross(right, lookat), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0

    # get directions by dividing directions_unit_focal by focal length
    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
    directions_unit_focal = get_ray_directions(
        H=height,
        W=width,
        focal=1.0,
    )
    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
    directions[:, :, :, :2] = (
        directions[:, :, :, :2] / focal_length[:, None, None, None]
    )
    # must use normalize=True to normalize directions here
    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)

    return rays_o, rays_d

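get_spherical_cameras produces per-pixel ray origins and directions for a turntable of views around the object; the renderer then shades each (rays_o, rays_d) pair into one video frame. A sketch with plausible settings (the exact values are illustrative, not taken from this commit's scripts):

rays_o, rays_d = get_spherical_cameras(
    n_views=30,           # 30 frames around the object
    elevation_deg=0.0,
    camera_distance=1.9,
    fovy_deg=40.0,
    height=256,
    width=256,
)
# rays_o, rays_d: (30, 256, 256, 3); render each view, then pass the frames to save_video(...)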
def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image

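Together these two helpers implement the image preprocessing done before inference: strip the background (skipped when the input already carries a meaningful alpha channel, unless force=True), then crop and pad so the foreground occupies roughly ratio of a square frame, then composite onto neutral gray. A sketch in the spirit of the commit's run.py and app.py (the path and the 0.85 ratio are illustrative):

rembg_session = rembg.new_session()

image = PIL.Image.open("examples/chair.png")        # hypothetical input
image = remove_background(image, rembg_session)     # RGBA with background made transparent
image = resize_foreground(image, 0.85)              # crop + pad so the object fills ~85% of the frame

# composite onto 50% gray before feeding the model
rgba = np.array(image).astype(np.float32) / 255.0
rgb = rgba[:, :, :3] * rgba[:, :, 3:4] + (1 - rgba[:, :, 3:4]) * 0.5
image = PIL.Image.fromarray((rgb * 255.0).astype(np.uint8))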
def save_video(
    frames: List[PIL.Image.Image],
    output_path: str,
    fps: int = 30,
):
    # use imageio to save video
    frames = [np.array(frame) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    for frame in frames:
        writer.append_data(frame)
    writer.close()


def to_gradio_3d_orientation(mesh):
    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi / 2, [0, 1, 0]))
    return mesh
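to_gradio_3d_orientation rotates an extracted mesh so it displays upright in gradio's Model3D viewer before export. A short sketch in the spirit of the commit's gradio_app.py (the output filename is illustrative):

meshes = model.extract_mesh(scene_codes)
mesh = to_gradio_3d_orientation(meshes[0])
mesh.export("mesh.glb")   # trimesh can export .glb or .obj for the viewer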