Commit 6ba63c9

1 Parent(s): cbd253a

Add initial module structure and entry points for modeling and utilities

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +1 -0
- Dockerfile +77 -0
- README.md +5 -5
- colabs/ENVIRONMENT.md +6 -0
- colabs/biomedparse_inference_demo.py +156 -0
- colabs/environment.yml +149 -0
- colabs/requirements-colab-pip-freeze.txt +567 -0
- colabs/requirements-colab.txt +39 -0
- configs/biomedparse_inference.yaml +204 -0
- entrypoint.sh +5 -0
- examples/Part_1_516_pathology_breast.png +3 -0
- inference_utils/inference.py +149 -0
- inference_utils/output_processing.py +91 -0
- inference_utils/processing_utils.py +182 -0
- inference_utils/target_dist.json +1 -0
- main.py +106 -0
- modeling/BaseModel.py +45 -0
- modeling/__init__.py +1 -0
- modeling/architectures/__init__.py +5 -0
- modeling/architectures/build.py +22 -0
- modeling/architectures/seem_model_demo.py +923 -0
- modeling/architectures/seem_model_v0.py +1160 -0
- modeling/architectures/seem_model_v1.py +1179 -0
- modeling/architectures/xdecoder_model.py +937 -0
- modeling/body/__init__.py +10 -0
- modeling/body/build.py +13 -0
- modeling/body/xdecoder_head.py +126 -0
- modeling/interface/__init__.py +13 -0
- modeling/interface/build.py +14 -0
- modeling/interface/modules.py +200 -0
- modeling/interface/prototype/__init__.py +0 -0
- modeling/interface/prototype/attention_data_struct_seemdemo.py +265 -0
- modeling/interface/prototype/attention_data_struct_seemv0.py +264 -0
- modeling/interface/prototype/attention_data_struct_seemv1.py +302 -0
- modeling/interface/seem_demo.py +397 -0
- modeling/interface/seem_v0.py +392 -0
- modeling/interface/seem_v1.py +389 -0
- modeling/interface/xdecoder.py +497 -0
- modeling/language/LangEncoder/__init__.py +35 -0
- modeling/language/LangEncoder/build.py +16 -0
- modeling/language/LangEncoder/transformer.py +222 -0
- modeling/language/__init__.py +10 -0
- modeling/language/build.py +14 -0
- modeling/language/loss.py +232 -0
- modeling/language/misc.py +66 -0
- modeling/language/vlpencoder.py +206 -0
- modeling/modules/__init__.py +6 -0
- modeling/modules/attention.py +487 -0
- modeling/modules/criterion.py +874 -0
- modeling/modules/matcher.py +632 -0
    	
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
    	
Dockerfile ADDED

@@ -0,0 +1,77 @@
+# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+FROM continuumio/miniconda3:latest
+
+# Add build argument to force rebuild
+ARG CACHEBUST=1
+
+# Avoid tzdata interactive configuration
+ENV DEBIAN_FRONTEND=noninteractive
+ENV TZ=UTC
+
+# Install system dependencies
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+    git \
+    build-essential \
+    python3-dev \
+    wget \
+    openmpi-bin \
+    libopenmpi-dev \
+    libopenmpi3 \
+    libhwloc15 \
+    libevent-dev \
+    libpmix2 \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set up OpenMPI environment
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=none \
+    OMPI_ALLOW_RUN_AS_ROOT=1 \
+    OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \
+    PATH=/usr/lib/x86_64-linux-gnu/openmpi/bin:$PATH \
+    LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/openmpi/lib:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+# Copy environment file
+COPY colabs/environment.yml /tmp/environment.yml
+
+# Create conda environment
+RUN conda env create -f /tmp/environment.yml && \
+    conda run -n biomedparse pip install gradio==3.50.2
+
+# Initialize conda in bash
+RUN conda init bash
+
+# Make RUN commands use the new environment
+SHELL ["conda", "run", "-n", "biomedparse", "/bin/bash", "-c"]
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+# Set up HF token for the user
+RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
+    echo "export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)" >> $HOME/.bashrc
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy all files to the app directory
+COPY --chown=user . $HOME/app
+
+# Set permissions for entrypoint script
+RUN chmod 755 $HOME/app/entrypoint.sh
+
+# Add conda environment to user's path
+RUN echo "conda activate biomedparse" >> $HOME/.bashrc
+
+# Use entrypoint script to set up environment and run application
+ENTRYPOINT ["/bin/bash", "-c"]
+CMD ["exec /home/user/app/entrypoint.sh"]
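Note on the HF_TOKEN secret: the RUN --mount=type=secret step above writes the token into the user's ~/.bashrc so it is exported when the container runs. A minimal sketch, not part of this commit, of how application code inside the container could then authenticate the gated checkpoint download with that variable (repo id and filename follow the inference demo later in this commit; it assumes the entrypoint shell has sourced ~/.bashrc):

# Sketch (not part of this commit): authenticate a gated download with the
# HF_TOKEN exported by the secret-mount step above.
import os
from huggingface_hub import hf_hub_download

token = os.environ.get("HF_TOKEN")  # assumed to be exported via ~/.bashrc at runtime
checkpoint = hf_hub_download(
    repo_id="microsoft/BiomedParse",
    filename="biomedparse_v1.pt",
    local_dir="pretrained",
    token=token,
)
print(f"Checkpoint available at: {checkpoint}")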
    	
README.md CHANGED

@@ -1,11 +1,11 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Biomedparse Docker
+emoji: 📉
+colorFrom: yellow
+colorTo: blue
 sdk: docker
 pinned: false
-
+license: cc-by-nc-sa-4.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
colabs/ENVIRONMENT.md ADDED

@@ -0,0 +1,6 @@
+# Description of Google Colab Environment
+
+- Hardware: Python 3 Google Compute Engine Backend on T4 GPU
+- CUDA version: 12.2
+- Driver Version: 535.104.05
+- Python version: 3.10.12
    	
colabs/biomedparse_inference_demo.py ADDED

@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+"""biomedparse_inference_demo.ipynb
+
+Automatically generated by Colab.
+
+Original file is located at
+    https://colab.research.google.com/drive/1jL4wvdtBWz6G_yBkFn8tyDD0hV1RtKVZ
+
+# BiomedParse Inference Demo Notebook
+
+Welcome to the demo notebook for BiomedParse, a comprehensive tool for biomedical image analysis. BiomedParse is designed to simultaneously handle segmentation, detection, and recognition tasks across major biomedical image modalities, providing a unified solution for complex image analysis in biomedical research.
+
+[[`Paper`](https://aka.ms/biomedparse-paper)] [[`Demo`](https://microsoft.github.io/BiomedParse/)] [[`Model`](https://huggingface.co/microsoft/BiomedParse)]  [[`Data`](https://huggingface.co/datasets/microsoft/BiomedParseData)]
+
+## Model Checkpoint Access
+
+The BiomedParse model checkpoint is hosted on [HuggingFace](https://huggingface.co/microsoft/BiomedParse). To access the model:
+
+1. Visit the [model page](https://huggingface.co/microsoft/BiomedParse).
+2. Make sure to review and accept the terms of use to gain access to the checkpoint.
+3. Retrieve your HuggingFace access token from your user profile.
+
+## Setting Up Access
+
+To use the model, set your Hugging Face access token in the HF_TOKEN environment variable or as a Colab secret. This step ensures secure and authorized access to the model resources.
+"""
+
+# Set your Hugging Face access token in your environment
+# import os
+# os.environ['HF_TOKEN'] = 'your_huggingface_access_token_here'
+
+# Or, if you are using Google Colab, set HF_TOKEN on Colab secrets.
+
+from google.colab import userdata
+import huggingface_hub
+
+huggingface_hub.login(userdata.get('HF_TOKEN'))
+
+from huggingface_hub import hf_hub_download
+
+model_file = hf_hub_download(repo_id="microsoft/BiomedParse", filename="biomedparse_v1.pt", local_dir="pretrained")
+
+print(f"Downloaded model file to: {model_file}")
+
+"""## Environment Setup"""
+
+!git clone https://github.com/microsoft/BiomedParse
+
+!pip install -r BiomedParse/assets/requirements/requirements.txt
+
+"""# Restart Colab Runtime"""
+
+# Make sure to restart Colab runtime after installing dependencies
+import os
+try:
+    import google.colab
+    os._exit(0)
+except ImportError:
+    pass
+
+import os
+os.chdir('/content/BiomedParse')
+print(os.getcwd())
+
+"""## Load the model weights"""
+
+from PIL import Image
+import torch
+import argparse
+import numpy as np
+from modeling.BaseModel import BaseModel
+from modeling import build_model
+from utilities.distributed import init_distributed # changed from utils
+from utilities.arguments import load_opt_from_config_files
+from utilities.constants import BIOMED_CLASSES
+from inference_utils.inference import interactive_infer_image
+
+conf_files = "configs/biomedparse_inference.yaml"
+opt = load_opt_from_config_files([conf_files])
+opt = init_distributed(opt)
+
+model_file = "../pretrained/biomedparse_v1.pt"
+
+model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
+with torch.no_grad():
+    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)
+
+"""# Run Inference"""
+
+# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
+image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
+image = image.convert('RGB')
+
+# text prompts querying objects in the image. Multiple ones can be provided.
+prompts = ['neoplastic cells', 'inflammatory cells']
+
+pred_mask = interactive_infer_image(model, image, prompts)
+pred_mask.shape
+
+# load ground truth mask
+gt_masks = []
+for prompt in prompts:
+    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
+    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
+    gt_masks.append(gt_mask)
+
+# prediction with ground truth mask
+for i, pred in enumerate(pred_mask):
+    gt = gt_masks[i]
+    dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
+    print(f'Dice score for {prompts[i]}: {dice:.4f}')
+
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import matplotlib.patches as mpatches
+
+def overlay_masks(image, masks, colors):
+    overlay = image.copy()
+    overlay = np.array(overlay, dtype=np.uint8)
+    for mask, color in zip(masks, colors):
+        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(np.uint8)
+    return Image.fromarray(overlay)
+
+def generate_colors(n):
+    cmap = plt.get_cmap('tab10')
+    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
+    return colors
+
+original_image = Image.open('examples/Part_1_516_pathology_breast.png').convert('RGB')
+
+colors = generate_colors(len(prompts))
+
+pred_overlay = overlay_masks(original_image, [1*(pred_mask[i] > 0.5) for i in range(len(prompts))], colors)
+
+gt_overlay = overlay_masks(original_image, gt_masks, colors)
+
+legend_patches = [mpatches.Patch(color=np.array(color) / 255, label=prompt) for color, prompt in zip(colors, prompts)]
+
+fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+axes[0].imshow(original_image)
+axes[0].set_title("Original Image")
+axes[0].axis('off')
+
+axes[1].imshow(pred_overlay)
+axes[1].set_title("Predictions")
+axes[1].axis('off')
+axes[1].legend(handles=legend_patches, loc='upper right', fontsize='small')
+
+axes[2].imshow(gt_overlay)
+axes[2].set_title("Ground Truth")
+axes[2].axis('off')
+axes[2].legend(handles=legend_patches, loc='upper right', fontsize='small')
+
+plt.tight_layout()
+plt.show()
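For clarity, the Dice evaluation loop in the demo above implements the standard overlap metric Dice = 2|P∩G| / (|P| + |G|) on the thresholded prediction P and the binary ground truth G. A self-contained sketch of the same computation (hypothetical helper, not part of the commit):

# Sketch (not part of this commit): Dice score between a predicted probability
# map and a binary ground-truth mask, matching the formula used in the demo loop.
import numpy as np

def dice_score(pred_prob: np.ndarray, gt: np.ndarray, threshold: float = 0.5) -> float:
    """Dice = 2*|P & G| / (|P| + |G|) on binarized masks."""
    pred = (pred_prob > threshold).astype(np.uint8)
    gt = (gt > 0).astype(np.uint8)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * np.logical_and(pred, gt).sum() / denom

# Example with random data; in the demo, pred_prob comes from interactive_infer_image
# and gt from the examples/*_<prompt>.png ground-truth masks.
rng = np.random.default_rng(0)
pred_prob = rng.random((64, 64))
gt = (rng.random((64, 64)) > 0.5).astype(np.uint8)
print(f"Dice: {dice_score(pred_prob, gt):.4f}")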
    	
colabs/environment.yml ADDED

@@ -0,0 +1,149 @@
+name: biomedparse
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - brotli-python=1.0.9=py39h6a678d5_8
+  - bzip2=1.0.8=h5eee18b_6
+  - ca-certificates=2024.7.2=h06a4308_0
+  - certifi=2024.7.4=py39h06a4308_0
+  - charset-normalizer=3.3.2=pyhd3eb1b0_0
+  - cuda-cudart=12.4.127=0
+  - cuda-cupti=12.4.127=0
+  - cuda-libraries=12.4.0=0
+  - cuda-nvrtc=12.4.127=0
+  - cuda-nvtx=12.4.127=0
+  - cuda-opencl=12.6.37=0
+  - cuda-runtime=12.4.0=0
+  - cuda-version=12.6=3
+  - ffmpeg=4.3=hf484d3e_0
+  - filelock=3.13.1=py39h06a4308_0
+  - freetype=2.12.1=h4a9f257_0
+  - gmp=6.2.1=h295c915_3
+  - gmpy2=2.1.2=py39heeb90bb_0
+  - gnutls=3.6.15=he1e5248_0
+  - idna=3.7=py39h06a4308_0
+  - intel-openmp=2023.1.0=hdb19cb5_46306
+  - jinja2=3.1.4=py39h06a4308_0
+  - jpeg=9e=h5eee18b_3
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libcublas=12.4.2.65=0
+  - libcufft=11.2.0.44=0
+  - libcufile=1.11.0.15=0
+  - libcurand=10.3.7.37=0
+  - libcusolver=11.6.0.99=0
+  - libcusparse=12.3.0.142=0
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.16=h5eee18b_3
+  - libidn2=2.3.4=h5eee18b_0
+  - libjpeg-turbo=2.0.0=h9bf148f_0
+  - libnpp=12.2.5.2=0
+  - libnvfatbin=12.6.20=0
+  - libnvjitlink=12.4.99=0
+  - libnvjpeg=12.3.1.89=0
+  - libpng=1.6.39=h5eee18b_0
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.1=h6a678d5_0
+  - libunistring=0.9.10=h27cfd23_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - llvm-openmp=14.0.6=h9e868ea_0
+  - lz4-c=1.9.4=h6a678d5_1
+  - markupsafe=2.1.3=py39h5eee18b_0
+  - mkl=2023.1.0=h213fc3f_46344
+  - mkl-service=2.4.0=py39h5eee18b_1
+  - mkl_fft=1.3.8=py39h5eee18b_0
+  - mkl_random=1.2.4=py39hdb19cb5_0
+  - mpc=1.1.0=h10f8cd9_1
+  - mpfr=4.0.2=hb69a4c5_1
+  - mpmath=1.3.0=py39h06a4308_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - networkx=3.2.1=py39h06a4308_0
+  - openh264=2.1.1=h4ff587b_0
+  - openjpeg=2.5.2=he7f1fd0_0
+  - openssl=3.0.14=h5eee18b_0
+  - pip=24.2=py39h06a4308_0
+  - pysocks=1.7.1=py39h06a4308_0
+  - python=3.9.19=h955ad1f_1
+  - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
+  - pytorch-cuda=12.4=hc786d27_6
+  - pytorch-mutex=1.0=cuda
+  - pyyaml=6.0.1=py39h5eee18b_0
+  - readline=8.2=h5eee18b_0
+  - requests=2.32.3=py39h06a4308_0
+  - setuptools=72.1.0=py39h06a4308_0
+  - sqlite=3.45.3=h5eee18b_0
+  - sympy=1.12=py39h06a4308_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.14=h39e8969_0
+  - torchaudio=2.4.0=py39_cu124
+  - torchtriton=3.0.0=py39
+  - torchvision=0.19.0=py39_cu124
+  - typing_extensions=4.11.0=py39h06a4308_0
+  - tzdata=2024a=h04d1e81_0
+  - urllib3=2.2.2=py39h06a4308_0
+  - wheel=0.43.0=py39h06a4308_0
+  - xz=5.4.6=h5eee18b_1
+  - yaml=0.2.5=h7b6447c_0
+  - zlib=1.2.13=h5eee18b_1
+  - zstd=1.5.5=hc292b87_2
+  - pip:
+      - accelerate==0.23.0
+      - antlr4-python3-runtime==4.9.3
+      - appdirs==1.4.4
+      - black==21.4b2
+      - open-clip-torch==2.26.1
+      - cloudpickle==3.0.0
+      - cython==3.0.2
+      - deepspeed==0.10.3
+      - git+https://github.com/MaureenZOU/detectron2-xyz.git
+      - diffdist==0.1
+      - einops==0.8.0
+      - ftfy==6.1.1
+      - fvcore==0.1.5.post20221221
+      - hjson==3.1.0
+      - huggingface-hub==0.17.3
+      - hydra-core==1.3.2
+      - imageio==2.35.1
+      - infinibatch==0.1.1
+      - iopath==0.1.9
+      - json-tricks==3.17.3
+      - kornia==0.7.0
+      - mpi4py==3.1.5
+      - mup==1.0.0
+      - mypy-extensions==1.0.0
+      - ninja==1.11.1.1
+      - nltk==3.8.1
+      - numpy==1.23.1
+      - omegaconf==2.3.0
+      - opencv-python==4.8.1.78
+      - pandas==2.0.3
+      - pathspec==0.12.1
+      - pillow==9.4.0
+      - portalocker==2.10.1
+      - py-cpuinfo==9.0.0
+      - pycocotools==2.0.7
+      - pydantic==1.10.18
+      - pydot==3.0.1
+      - regex==2023.10.3
+      - scikit-image==0.21.0
+      - scikit-learn==1.3.1
+      - sentencepiece==0.1.99
+      - tabulate==0.9.0
+      - termcolor==2.4.0
+      - timm==0.4.12
+      - tokenizers==0.14.1
+      - transformers==4.34.0
+      - vision-datasets==0.2.2
+      - yacs==0.1.8
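The environment above pins PyTorch 2.4.0 against CUDA 12.4 on Python 3.9. A small sanity-check sketch (not part of the commit; it assumes the environment was created as in the Dockerfile with conda env create -f colabs/environment.yml and is currently active) to confirm the GPU-enabled builds resolved as pinned:

# Sketch (not part of this commit): run inside the "biomedparse" conda environment
# to confirm the pinned builds resolved as expected.
import sys
import torch
import torchvision

print("Python:", sys.version.split()[0])        # expected 3.9.x per the pins
print("torch:", torch.__version__)              # expected 2.4.0
print("torchvision:", torchvision.__version__)  # expected 0.19.0
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))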
    	
colabs/requirements-colab-pip-freeze.txt ADDED

@@ -0,0 +1,567 @@
+absl-py==1.4.0
+accelerate==0.23.0
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.10
+aiosignal==1.3.2
+alabaster==1.0.0
+albucore==0.0.19
+albumentations==1.4.20
+altair==5.5.0
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==3.7.1
+appdirs==1.4.4
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+array_record==0.5.1
+arviz==0.20.0
+astropy==6.1.7
+astropy-iers-data==0.2024.12.16.0.35.48
+astunparse==1.6.3
+async-timeout==4.0.3
+atpublic==4.1.0
+attrs==24.3.0
+audioread==3.0.1
+autograd==1.7.0
+babel==2.16.0
+backcall==0.2.0
+beautifulsoup4==4.12.3
+bigframes==1.29.0
+bigquery-magics==0.4.0
+black==21.4b2
+bleach==6.2.0
+blinker==1.9.0
+blis==0.7.11
+blosc2==2.7.1
+bokeh==3.6.2
+Bottleneck==1.4.2
+bqplot==0.12.43
+branca==0.8.1
+CacheControl==0.14.1
+cachetools==5.5.0
+catalogue==2.0.10
+certifi==2024.12.14
+cffi==1.17.1
+chardet==5.2.0
+charset-normalizer==3.4.0
+chex==0.1.88
+clarabel==0.9.0
+click==8.1.7
+cloudpathlib==0.20.0
+cloudpickle==3.1.0
+cmake==3.31.2
+cmdstanpy==1.2.5
+colorcet==3.1.0
+colorlover==0.3.0
+colour==0.1.5
+community==1.0.0b1
+confection==0.1.5
+cons==0.4.6
+contourpy==1.3.1
+cryptography==43.0.3
+cuda-python==12.2.1
+cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
+cufflinks==0.17.3
+cupy-cuda12x==12.2.0
+cvxopt==1.3.2
+cvxpy==1.6.0
+cycler==0.12.1
+cymem==2.0.10
+Cython==3.0.2
+dask==2024.10.0
+datascience==0.17.6
+db-dtypes==1.3.1
+dbus-python==1.2.18
+debugpy==1.8.0
+decorator==4.4.2
+deepspeed==0.10.3
+defusedxml==0.7.1
+Deprecated==1.2.15
+detectron2 @ git+https://github.com/MaureenZOU/detectron2-xyz.git@42121d75e10d9f858f3a91b6a39f5722c02868f0
+diffdist==0.1
+diffusers==0.31.0
+distro==1.9.0
+dlib==19.24.2
+dm-tree==0.1.8
+docker-pycreds==0.4.0
+docstring_parser==0.16
+docutils==0.21.2
+dopamine_rl==4.1.0
+duckdb==1.1.3
+earthengine-api==1.4.3
+easydict==1.13
+editdistance==0.8.1
+eerepr==0.0.4
+einops==0.8.0
+en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
+entrypoints==0.4
+et_xmlfile==2.0.0
+etils==1.11.0
+etuples==0.3.9
+eval_type_backport==0.2.0
+exceptiongroup==1.2.2
+fastai==2.7.18
+fastcore==1.7.27
+fastdownload==0.0.7
+fastjsonschema==2.21.1
+fastprogress==1.0.3
+fastrlock==0.8.3
+filelock==3.16.1
+firebase-admin==6.6.0
+Flask==3.1.0
+flatbuffers==24.3.25
+flax==0.8.5
+folium==0.19.2
+fonttools==4.55.3
+frozendict==2.4.6
+frozenlist==1.5.0
+fsspec==2024.10.0
+ftfy==6.1.1
+future==1.0.0
+fvcore==0.1.5.post20221221
+gast==0.6.0
+gcsfs==2024.10.0
+GDAL==3.6.4
+gdown==5.2.0
+geemap==0.35.1
+gensim==4.3.3
+geocoder==1.38.1
+geographiclib==2.0
+geopandas==1.0.1
+geopy==2.4.1
+gin-config==0.5.0
+gitdb==4.0.11
+GitPython==3.1.43
+glob2==0.7
+google==2.0.3
+google-ai-generativelanguage==0.6.10
+google-api-core==2.19.2
+google-api-python-client==2.155.0
+google-auth==2.27.0
+google-auth-httplib2==0.2.0
+google-auth-oauthlib==1.2.1
+google-cloud-aiplatform==1.74.0
+google-cloud-bigquery==3.25.0
+google-cloud-bigquery-connection==1.17.0
+google-cloud-bigquery-storage==2.27.0
+google-cloud-bigtable==2.27.0
+google-cloud-core==2.4.1
+google-cloud-datastore==2.20.2
+google-cloud-firestore==2.19.0
+google-cloud-functions==1.19.0
+google-cloud-iam==2.17.0
+google-cloud-language==2.16.0
+google-cloud-pubsub==2.27.1
+google-cloud-resource-manager==1.14.0
+google-cloud-storage==2.19.0
+google-cloud-translate==3.19.0
+google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
+google-crc32c==1.6.0
+google-genai==0.3.0
+google-generativeai==0.8.3
+google-pasta==0.2.0
+google-resumable-media==2.7.2
+googleapis-common-protos==1.66.0
+googledrivedownloader==0.4
+graphviz==0.20.3
+greenlet==3.1.1
+grpc-google-iam-v1==0.13.1
+grpcio==1.68.1
+grpcio-status==1.62.3
+gspread==6.0.2
+gspread-dataframe==3.3.1
+gym==0.25.2
+gym-notices==0.0.8
+h11==0.14.0
+h5netcdf==1.4.1
+h5py==3.12.1
+hjson==3.1.0
+holidays==0.63
+holoviews==1.20.0
+html5lib==1.1
+httpcore==1.0.7
+httpimport==1.4.0
+httplib2==0.22.0
+httpx==0.28.1
+huggingface-hub==0.17.3
+humanize==4.11.0
+hydra-core==1.3.2
+hyperopt==0.2.7
+ibis-framework==9.2.0
+idna==3.10
+imageio==2.36.1
+imageio-ffmpeg==0.5.1
+imagesize==1.4.1
+imbalanced-learn==0.12.4
+imgaug==0.4.0
+immutabledict==4.2.1
+importlib_metadata==8.5.0
+importlib_resources==6.4.5
+imutils==0.5.4
+infinibatch==0.1.1
+inflect==7.4.0
+iniconfig==2.0.0
+intel-cmplr-lib-ur==2025.0.4
+intel-openmp==2025.0.4
+iopath==0.1.9
+ipyevents==2.0.2
+ipyfilechooser==0.6.0
+ipykernel==5.5.6
+ipyleaflet==0.19.2
+ipyparallel==8.8.0
+ipython==7.34.0
+ipython-genutils==0.2.0
+ipython-sql==0.5.0
+ipytree==0.2.2
+ipywidgets==7.7.1
+itsdangerous==2.2.0
+jax==0.4.33
+jax-cuda12-pjrt==0.4.33
+jax-cuda12-plugin==0.4.33
+jaxlib==0.4.33
+jeepney==0.7.1
+jellyfish==1.1.0
+jieba==0.42.1
+Jinja2==3.1.4
+jiter==0.8.2
+joblib==1.4.2
+json-tricks==3.17.3
+jsonpatch==1.33
+jsonpickle==4.0.1
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter-client==6.1.12
+jupyter-console==6.1.0
+jupyter-leaflet==0.19.2
+jupyter-server==1.24.0
+jupyter_core==5.7.2
+jupyterlab_pygments==0.3.0
+jupyterlab_widgets==3.0.13
+kaggle==1.6.17
+kagglehub==0.3.5
+keras==3.5.0
+keyring==23.5.0
+kiwisolver==1.4.7
+kornia==0.7.0
+langchain==0.3.12
+langchain-core==0.3.25
+langchain-text-splitters==0.3.3
+langcodes==3.5.0
+langsmith==0.2.3
+language_data==1.3.0
+launchpadlib==1.10.16
+lazr.restfulclient==0.14.4
+lazr.uri==1.0.6
+lazy_loader==0.4
+libclang==18.1.1
+libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-24.10.1-py3-none-manylinux_2_28_x86_64.whl
+librosa==0.10.2.post1
+lightgbm==4.5.0
+linkify-it-py==2.0.3
+llvmlite==0.43.0
+locket==1.0.0
+logical-unification==0.4.6
+lxml==5.3.0
+marisa-trie==1.2.1
+Markdown==3.7
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.8.0
+matplotlib-inline==0.1.7
+matplotlib-venn==1.1.1
+mdit-py-plugins==0.4.2
+mdurl==0.1.2
+miniKanren==1.0.3
+missingno==0.5.2
+mistune==3.0.2
+mizani==0.13.1
+mkl==2025.0.1
+ml-dtypes==0.4.1
+mlxtend==0.23.3
+more-itertools==10.5.0
+moviepy==1.0.3
+mpi4py==3.1.5
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.1.0
+multipledispatch==1.0.0
+multitasking==0.0.11
         | 
| 290 | 
            +
            mup==1.0.0
         | 
| 291 | 
            +
            murmurhash==1.0.11
         | 
| 292 | 
            +
            music21==9.3.0
         | 
| 293 | 
            +
            mypy-extensions==1.0.0
         | 
| 294 | 
            +
            namex==0.0.8
         | 
| 295 | 
            +
            narwhals==1.18.4
         | 
| 296 | 
            +
            natsort==8.4.0
         | 
| 297 | 
            +
            nbclassic==1.1.0
         | 
| 298 | 
            +
            nbclient==0.10.1
         | 
| 299 | 
            +
            nbconvert==7.16.4
         | 
| 300 | 
            +
            nbformat==5.10.4
         | 
| 301 | 
            +
            ndindex==1.9.2
         | 
| 302 | 
            +
            nest-asyncio==1.6.0
         | 
| 303 | 
            +
            networkx==3.4.2
         | 
| 304 | 
            +
            nibabel==5.3.2
         | 
| 305 | 
            +
            ninja==1.11.1.3
         | 
| 306 | 
            +
            nltk==3.8.1
         | 
| 307 | 
            +
            notebook==6.5.5
         | 
| 308 | 
            +
            notebook_shim==0.2.4
         | 
| 309 | 
            +
            numba==0.60.0
         | 
| 310 | 
            +
            numexpr==2.10.2
         | 
| 311 | 
            +
            numpy==1.26.4
         | 
| 312 | 
            +
            nvidia-cublas-cu12==12.6.4.1
         | 
| 313 | 
            +
            nvidia-cuda-cupti-cu12==12.6.80
         | 
| 314 | 
            +
            nvidia-cuda-nvcc-cu12==12.6.85
         | 
| 315 | 
            +
            nvidia-cuda-runtime-cu12==12.6.77
         | 
| 316 | 
            +
            nvidia-cudnn-cu12==9.6.0.74
         | 
| 317 | 
            +
            nvidia-cufft-cu12==11.3.0.4
         | 
| 318 | 
            +
            nvidia-curand-cu12==10.3.7.77
         | 
| 319 | 
            +
            nvidia-cusolver-cu12==11.7.1.2
         | 
| 320 | 
            +
            nvidia-cusparse-cu12==12.5.4.2
         | 
| 321 | 
            +
            nvidia-nccl-cu12==2.23.4
         | 
| 322 | 
            +
            nvidia-nvjitlink-cu12==12.6.85
         | 
| 323 | 
            +
            nvtx==0.2.10
         | 
| 324 | 
            +
            nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-24.10.0-py3-none-any.whl
         | 
| 325 | 
            +
            oauth2client==4.1.3
         | 
| 326 | 
            +
            oauthlib==3.2.2
         | 
| 327 | 
            +
            omegaconf==2.3.0
         | 
| 328 | 
            +
            open_clip_torch==2.26.1
         | 
| 329 | 
            +
            openai==1.57.4
         | 
| 330 | 
            +
            opencv-contrib-python==4.10.0.84
         | 
| 331 | 
            +
            opencv-python==4.8.1.78
         | 
| 332 | 
            +
            opencv-python-headless==4.10.0.84
         | 
| 333 | 
            +
            openpyxl==3.1.5
         | 
| 334 | 
            +
            opentelemetry-api==1.29.0
         | 
| 335 | 
            +
            opentelemetry-sdk==1.29.0
         | 
| 336 | 
            +
            opentelemetry-semantic-conventions==0.50b0
         | 
| 337 | 
            +
            opt_einsum==3.4.0
         | 
| 338 | 
            +
            optax==0.2.4
         | 
| 339 | 
            +
            optree==0.13.1
         | 
| 340 | 
            +
            orbax-checkpoint==0.6.4
         | 
| 341 | 
            +
            orjson==3.10.12
         | 
| 342 | 
            +
            osqp==0.6.7.post3
         | 
| 343 | 
            +
            packaging==24.2
         | 
| 344 | 
            +
            pandas==2.0.3
         | 
| 345 | 
            +
            pandas-datareader==0.10.0
         | 
| 346 | 
            +
            pandas-gbq==0.25.0
         | 
| 347 | 
            +
            pandas-stubs==2.2.2.240909
         | 
| 348 | 
            +
            pandocfilters==1.5.1
         | 
| 349 | 
            +
            panel==1.5.4
         | 
| 350 | 
            +
            param==2.2.0
         | 
| 351 | 
            +
            parso==0.8.4
         | 
| 352 | 
            +
            parsy==2.1
         | 
| 353 | 
            +
            partd==1.4.2
         | 
| 354 | 
            +
            pathlib==1.0.1
         | 
| 355 | 
            +
            pathspec==0.12.1
         | 
| 356 | 
            +
            patsy==1.0.1
         | 
| 357 | 
            +
            peewee==3.17.8
         | 
| 358 | 
            +
            peft==0.14.0
         | 
| 359 | 
            +
            pexpect==4.9.0
         | 
| 360 | 
            +
            pickleshare==0.7.5
         | 
| 361 | 
            +
            Pillow==9.4.0
         | 
| 362 | 
            +
            platformdirs==4.3.6
         | 
| 363 | 
            +
            plotly==5.24.1
         | 
| 364 | 
            +
            plotnine==0.14.4
         | 
| 365 | 
            +
            pluggy==1.5.0
         | 
| 366 | 
            +
            ply==3.11
         | 
| 367 | 
            +
            polars==1.9.0
         | 
| 368 | 
            +
            pooch==1.8.2
         | 
| 369 | 
            +
            portalocker==3.0.0
         | 
| 370 | 
            +
            portpicker==1.5.2
         | 
| 371 | 
            +
            preshed==3.0.9
         | 
| 372 | 
            +
            prettytable==3.12.0
         | 
| 373 | 
            +
            proglog==0.1.10
         | 
| 374 | 
            +
            progressbar2==4.5.0
         | 
| 375 | 
            +
            prometheus_client==0.21.1
         | 
| 376 | 
            +
            promise==2.3
         | 
| 377 | 
            +
            prompt_toolkit==3.0.48
         | 
| 378 | 
            +
            propcache==0.2.1
         | 
| 379 | 
            +
            prophet==1.1.6
         | 
| 380 | 
            +
            proto-plus==1.25.0
         | 
| 381 | 
            +
            protobuf==4.25.5
         | 
| 382 | 
            +
            psutil==5.9.5
         | 
| 383 | 
            +
            psycopg2==2.9.10
         | 
| 384 | 
            +
            ptyprocess==0.7.0
         | 
| 385 | 
            +
            py-cpuinfo==9.0.0
         | 
| 386 | 
            +
            py4j==0.10.9.7
         | 
| 387 | 
            +
            pyarrow==17.0.0
         | 
| 388 | 
            +
            pyasn1==0.6.1
         | 
| 389 | 
            +
            pyasn1_modules==0.4.1
         | 
| 390 | 
            +
            pycocotools==2.0.7
         | 
| 391 | 
            +
            pycparser==2.22
         | 
| 392 | 
            +
            pydantic==1.10.19
         | 
| 393 | 
            +
            pydantic_core==2.27.1
         | 
| 394 | 
            +
            pydata-google-auth==1.9.0
         | 
| 395 | 
            +
            pydot==3.0.3
         | 
| 396 | 
            +
            pydotplus==2.0.2
         | 
| 397 | 
            +
            PyDrive==1.3.1
         | 
| 398 | 
            +
            PyDrive2==1.21.3
         | 
| 399 | 
            +
            pyerfa==2.0.1.5
         | 
| 400 | 
            +
            pygame==2.6.1
         | 
| 401 | 
            +
            pygit2==1.16.0
         | 
| 402 | 
            +
            Pygments==2.18.0
         | 
| 403 | 
            +
            PyGObject==3.42.1
         | 
| 404 | 
            +
            PyJWT==2.10.1
         | 
| 405 | 
            +
            pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
         | 
| 406 | 
            +
            pylibcugraph-cu12==24.10.0
         | 
| 407 | 
            +
            pylibraft-cu12==24.10.0
         | 
| 408 | 
            +
            pymc==5.19.1
         | 
| 409 | 
            +
            pymystem3==0.2.0
         | 
| 410 | 
            +
            pynvjitlink-cu12==0.4.0
         | 
| 411 | 
            +
            pyogrio==0.10.0
         | 
| 412 | 
            +
            Pyomo==6.8.2
         | 
| 413 | 
            +
            PyOpenGL==3.1.7
         | 
| 414 | 
            +
            pyOpenSSL==24.2.1
         | 
| 415 | 
            +
            pyparsing==3.2.0
         | 
| 416 | 
            +
            pyperclip==1.9.0
         | 
| 417 | 
            +
            pyproj==3.7.0
         | 
| 418 | 
            +
            pyshp==2.3.1
         | 
| 419 | 
            +
            PySocks==1.7.1
         | 
| 420 | 
            +
            pyspark==3.5.3
         | 
| 421 | 
            +
            pytensor==2.26.4
         | 
| 422 | 
            +
            pytest==8.3.4
         | 
| 423 | 
            +
            python-apt==0.0.0
         | 
| 424 | 
            +
            python-box==7.3.0
         | 
| 425 | 
            +
            python-dateutil==2.8.2
         | 
| 426 | 
            +
            python-louvain==0.16
         | 
| 427 | 
            +
            python-slugify==8.0.4
         | 
| 428 | 
            +
            python-utils==3.9.1
         | 
| 429 | 
            +
            pytz==2024.2
         | 
| 430 | 
            +
            pyviz_comms==3.0.3
         | 
| 431 | 
            +
            PyWavelets==1.8.0
         | 
| 432 | 
            +
            PyYAML==6.0.1
         | 
| 433 | 
            +
            pyzmq==24.0.1
         | 
| 434 | 
            +
            qdldl==0.1.7.post4
         | 
| 435 | 
            +
            ratelim==0.1.6
         | 
| 436 | 
            +
            referencing==0.35.1
         | 
| 437 | 
            +
            regex==2023.10.3
         | 
| 438 | 
            +
            requests==2.32.3
         | 
| 439 | 
            +
            requests-oauthlib==1.3.1
         | 
| 440 | 
            +
            requests-toolbelt==1.0.0
         | 
| 441 | 
            +
            requirements-parser==0.9.0
         | 
| 442 | 
            +
            rich==13.9.4
         | 
| 443 | 
            +
            rmm-cu12==24.10.0
         | 
| 444 | 
            +
            rpds-py==0.22.3
         | 
| 445 | 
            +
            rpy2==3.4.2
         | 
| 446 | 
            +
            rsa==4.9
         | 
| 447 | 
            +
            safetensors==0.4.5
         | 
| 448 | 
            +
            scikit-image==0.21.0
         | 
| 449 | 
            +
            scikit-learn==1.3.1
         | 
| 450 | 
            +
            scipy==1.13.1
         | 
| 451 | 
            +
            scooby==0.10.0
         | 
| 452 | 
            +
            scs==3.2.7
         | 
| 453 | 
            +
            seaborn==0.13.2
         | 
| 454 | 
            +
            SecretStorage==3.3.1
         | 
| 455 | 
            +
            Send2Trash==1.8.3
         | 
| 456 | 
            +
            sentence-transformers==3.3.1
         | 
| 457 | 
            +
            sentencepiece==0.1.99
         | 
| 458 | 
            +
            sentry-sdk==2.19.2
         | 
| 459 | 
            +
            setproctitle==1.3.4
         | 
| 460 | 
            +
            shap==0.46.0
         | 
| 461 | 
            +
            shapely==2.0.6
         | 
| 462 | 
            +
            shellingham==1.5.4
         | 
| 463 | 
            +
            simple-parsing==0.1.6
         | 
| 464 | 
            +
            six==1.17.0
         | 
| 465 | 
            +
            sklearn-pandas==2.2.0
         | 
| 466 | 
            +
            slicer==0.0.8
         | 
| 467 | 
            +
            smart-open==7.1.0
         | 
| 468 | 
            +
            smmap==5.0.1
         | 
| 469 | 
            +
            sniffio==1.3.1
         | 
| 470 | 
            +
            snowballstemmer==2.2.0
         | 
| 471 | 
            +
            soundfile==0.12.1
         | 
| 472 | 
            +
            soupsieve==2.6
         | 
| 473 | 
            +
            soxr==0.5.0.post1
         | 
| 474 | 
            +
            spacy==3.7.5
         | 
| 475 | 
            +
            spacy-legacy==3.0.12
         | 
| 476 | 
            +
            spacy-loggers==1.0.5
         | 
| 477 | 
            +
            Sphinx==8.1.3
         | 
| 478 | 
            +
            sphinxcontrib-applehelp==2.0.0
         | 
| 479 | 
            +
            sphinxcontrib-devhelp==2.0.0
         | 
| 480 | 
            +
            sphinxcontrib-htmlhelp==2.1.0
         | 
| 481 | 
            +
            sphinxcontrib-jsmath==1.0.1
         | 
| 482 | 
            +
            sphinxcontrib-qthelp==2.0.0
         | 
| 483 | 
            +
            sphinxcontrib-serializinghtml==2.0.0
         | 
| 484 | 
            +
            SQLAlchemy==2.0.36
         | 
| 485 | 
            +
            sqlglot==25.1.0
         | 
| 486 | 
            +
            sqlparse==0.5.3
         | 
| 487 | 
            +
            srsly==2.5.0
         | 
| 488 | 
            +
            stanio==0.5.1
         | 
| 489 | 
            +
            statsmodels==0.14.4
         | 
| 490 | 
            +
            StrEnum==0.4.15
         | 
| 491 | 
            +
            stringzilla==3.11.1
         | 
| 492 | 
            +
            sympy==1.13.1
         | 
| 493 | 
            +
            tables==3.10.1
         | 
| 494 | 
            +
            tabulate==0.9.0
         | 
| 495 | 
            +
            tbb==2022.0.0
         | 
| 496 | 
            +
            tcmlib==1.2.0
         | 
| 497 | 
            +
            tenacity==9.0.0
         | 
| 498 | 
            +
            tensorboard==2.17.1
         | 
| 499 | 
            +
            tensorboard-data-server==0.7.2
         | 
| 500 | 
            +
            tensorflow==2.17.1
         | 
| 501 | 
            +
            tensorflow-datasets==4.9.7
         | 
| 502 | 
            +
            tensorflow-hub==0.16.1
         | 
| 503 | 
            +
            tensorflow-io-gcs-filesystem==0.37.1
         | 
| 504 | 
            +
            tensorflow-metadata==1.13.1
         | 
| 505 | 
            +
            tensorflow-probability==0.24.0
         | 
| 506 | 
            +
            tensorstore==0.1.71
         | 
| 507 | 
            +
            termcolor==2.5.0
         | 
| 508 | 
            +
            terminado==0.18.1
         | 
| 509 | 
            +
            text-unidecode==1.3
         | 
| 510 | 
            +
            textblob==0.17.1
         | 
| 511 | 
            +
            tf-slim==1.1.0
         | 
| 512 | 
            +
            tf_keras==2.17.0
         | 
| 513 | 
            +
            thinc==8.2.5
         | 
| 514 | 
            +
            threadpoolctl==3.5.0
         | 
| 515 | 
            +
            tifffile==2024.12.12
         | 
| 516 | 
            +
            timm==0.4.12
         | 
| 517 | 
            +
            tinycss2==1.4.0
         | 
| 518 | 
            +
            tokenizers==0.14.1
         | 
| 519 | 
            +
            toml==0.10.2
         | 
| 520 | 
            +
            tomli==2.2.1
         | 
| 521 | 
            +
            toolz==0.12.1
         | 
| 522 | 
            +
            torch @ https://download.pytorch.org/whl/cu121_full/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
         | 
| 523 | 
            +
            torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
         | 
| 524 | 
            +
            torchsummary==1.5.1
         | 
| 525 | 
            +
            torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl
         | 
| 526 | 
            +
            tornado==6.3.3
         | 
| 527 | 
            +
            tqdm==4.67.1
         | 
| 528 | 
            +
            traitlets==5.7.1
         | 
| 529 | 
            +
            traittypes==0.2.1
         | 
| 530 | 
            +
            transformers==4.34.0
         | 
| 531 | 
            +
            tweepy==4.14.0
         | 
| 532 | 
            +
            typeguard==4.4.1
         | 
| 533 | 
            +
            typer==0.15.1
         | 
| 534 | 
            +
            types-pytz==2024.2.0.20241003
         | 
| 535 | 
            +
            types-setuptools==75.6.0.20241126
         | 
| 536 | 
            +
            typing_extensions==4.12.2
         | 
| 537 | 
            +
            tzdata==2024.2
         | 
| 538 | 
            +
            tzlocal==5.2
         | 
| 539 | 
            +
            uc-micro-py==1.0.3
         | 
| 540 | 
            +
            umf==0.9.1
         | 
| 541 | 
            +
            uritemplate==4.1.1
         | 
| 542 | 
            +
            urllib3==2.2.3
         | 
| 543 | 
            +
            vega-datasets==0.9.0
         | 
| 544 | 
            +
            vision-datasets==0.2.2
         | 
| 545 | 
            +
            wadllib==1.3.6
         | 
| 546 | 
            +
            wandb==0.19.1
         | 
| 547 | 
            +
            wasabi==1.1.3
         | 
| 548 | 
            +
            wcwidth==0.2.13
         | 
| 549 | 
            +
            weasel==0.4.1
         | 
| 550 | 
            +
            webcolors==24.11.1
         | 
| 551 | 
            +
            webencodings==0.5.1
         | 
| 552 | 
            +
            websocket-client==1.8.0
         | 
| 553 | 
            +
            websockets==14.1
         | 
| 554 | 
            +
            Werkzeug==3.1.3
         | 
| 555 | 
            +
            widgetsnbextension==3.6.10
         | 
| 556 | 
            +
            wordcloud==1.9.4
         | 
| 557 | 
            +
            wrapt==1.17.0
         | 
| 558 | 
            +
            xarray==2024.11.0
         | 
| 559 | 
            +
            xarray-einstats==0.8.0
         | 
| 560 | 
            +
            xgboost==2.1.3
         | 
| 561 | 
            +
            xlrd==2.0.1
         | 
| 562 | 
            +
            xyzservices==2024.9.0
         | 
| 563 | 
            +
            yacs==0.1.8
         | 
| 564 | 
            +
            yarl==1.18.3
         | 
| 565 | 
            +
            yellowbrick==1.5
         | 
| 566 | 
            +
            yfinance==0.2.50
         | 
| 567 | 
            +
            zipp==3.21.0
         | 
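The freeze above pins every transitive dependency of the Colab environment. Before running the demo it can help to confirm that the active environment actually matches these pins; the following is only an illustrative sketch (the helper name and the warn-instead-of-fail behavior are assumptions, not part of this commit):

    from importlib.metadata import version, PackageNotFoundError

    def check_pins(freeze_path="colabs/requirements-colab-pip-freeze.txt"):
        # Report packages whose installed version differs from the freeze file.
        for line in open(freeze_path):
            line = line.strip()
            # Skip blanks, comments, and direct-URL pins such as "torch @ https://...".
            if not line or line.startswith("#") or " @ " in line or "==" not in line:
                continue
            name, pinned = line.split("==", 1)
            try:
                installed = version(name)
            except PackageNotFoundError:
                print(f"missing:  {name}=={pinned}")
                continue
            if installed != pinned:
                print(f"mismatch: {name} installed {installed}, pinned {pinned}")

    check_pins()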
    	
        colabs/requirements-colab.txt
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
| 1 | 
            +
            pillow==9.4.0
         | 
| 2 | 
            +
            opencv-python==4.8.1.78
         | 
| 3 | 
            +
            pyyaml==6.0.1
         | 
| 4 | 
            +
            json_tricks==3.17.3
         | 
| 5 | 
            +
            yacs==0.1.8
         | 
| 6 | 
            +
            scikit-learn==1.3.1
         | 
| 7 | 
            +
            pandas==2.0.3
         | 
| 8 | 
            +
            timm==0.4.12
         | 
| 9 | 
            +
            numpy==1.26.4
         | 
| 10 | 
            +
            einops==0.8.0
         | 
| 11 | 
            +
            fvcore==0.1.5.post20221221
         | 
| 12 | 
            +
            transformers==4.34.0
         | 
| 13 | 
            +
            sentencepiece==0.1.99
         | 
| 14 | 
            +
            ftfy==6.1.1
         | 
| 15 | 
            +
            regex==2023.10.3
         | 
| 16 | 
            +
            nltk==3.8.1
         | 
| 17 | 
            +
            mpi4py==3.1.5
         | 
| 18 | 
            +
            vision-datasets==0.2.2
         | 
| 19 | 
            +
            cython==3.0.2
         | 
| 20 | 
            +
            pycocotools==2.0.7
         | 
| 21 | 
            +
            diffdist==0.1
         | 
| 22 | 
            +
            #pyarrow==13.0.0
         | 
| 23 | 
            +
            #cityscapesscripts==2.2.2
         | 
| 24 | 
            +
            #shapely==1.8.0
         | 
| 25 | 
            +
            scikit-image==0.21.0
         | 
| 26 | 
            +
            mup==1.0.0
         | 
| 27 | 
            +
            accelerate==0.23.0
         | 
| 28 | 
            +
            kornia==0.7.0
         | 
| 29 | 
            +
            deepspeed==0.10.3
         | 
| 30 | 
            +
            #wandb==0.15.12
         | 
| 31 | 
            +
            infinibatch==0.1.1
         | 
| 32 | 
            +
            open-clip-torch==2.26.1
         | 
| 33 | 
            +
            git+https://github.com/MaureenZOU/detectron2-xyz.git
         | 
| 34 | 
            +
            #gradio==3.42.0
         | 
| 35 | 
            +
            #torch==2.3.1 #2.0.1 
         | 
| 36 | 
            +
            #torchvision==0.15.2 
         | 
| 37 | 
            +
            #torchaudio==2.0.2
         | 
| 38 | 
            +
            #torch==2.1.0
         | 
| 39 | 
            +
            #torchvision==0.16.0
         | 
    	
        configs/biomedparse_inference.yaml
    ADDED
    
    | @@ -0,0 +1,204 @@ | |
| 1 | 
            +
            # Define Test/Trainer/Saving
         | 
| 2 | 
            +
            PIPELINE: XDecoderPipeline
         | 
| 3 | 
            +
            TRAINER: xdecoder
         | 
| 4 | 
            +
            SAVE_DIR: "../../data/output/test"
         | 
| 5 | 
            +
            base_path: "./"
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Resume Logistic
         | 
| 8 | 
            +
            RESUME: false
         | 
| 9 | 
            +
            WEIGHT: false
         | 
| 10 | 
            +
            RESUME_FROM: ""
         | 
| 11 | 
            +
            EVAL_AT_START: false
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Logging and Debug
         | 
| 14 | 
            +
            WANDB: False
         | 
| 15 | 
            +
            LOG_EVERY: 100
         | 
| 16 | 
            +
            FIND_UNUSED_PARAMETERS: false
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Speed up training
         | 
| 19 | 
            +
            FP16: false
         | 
| 20 | 
            +
            PORT: "36873"
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            # misc
         | 
| 23 | 
            +
            LOADER:
         | 
| 24 | 
            +
              JOINT: False
         | 
| 25 | 
            +
              KEY_DATASET: "coco"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            STANDARD_TEXT_FOR_EVAL: False
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            ##################
         | 
| 30 | 
            +
            # Task settings
         | 
| 31 | 
            +
            ##################
         | 
| 32 | 
            +
            VERBOSE: true
         | 
| 33 | 
            +
            MODEL:
         | 
| 34 | 
            +
              NAME: seem_model_demo
         | 
| 35 | 
            +
              HEAD: xdecoder_head
         | 
| 36 | 
            +
              DIM_PROJ: 512
         | 
| 37 | 
            +
              TEXT:
         | 
| 38 | 
            +
                ARCH: vlpencoder
         | 
| 39 | 
            +
                NAME: transformer
         | 
| 40 | 
            +
                TOKENIZER: clip
         | 
| 41 | 
            +
                CONTEXT_LENGTH: 77 # 77
         | 
| 42 | 
            +
                WIDTH: 512
         | 
| 43 | 
            +
                HEADS: 8
         | 
| 44 | 
            +
                LAYERS: 12 # 6
         | 
| 45 | 
            +
                AUTOGRESSIVE: True
         | 
| 46 | 
            +
              BACKBONE:
         | 
| 47 | 
            +
                NAME: focal
         | 
| 48 | 
            +
                PRETRAINED: ""
         | 
| 49 | 
            +
                LOAD_PRETRAINED: false
         | 
| 50 | 
            +
                FOCAL:
         | 
| 51 | 
            +
                  PRETRAIN_IMG_SIZE: 224
         | 
| 52 | 
            +
                  PATCH_SIZE: 4
         | 
| 53 | 
            +
                  EMBED_DIM: 192
         | 
| 54 | 
            +
                  DEPTHS: [2, 2, 18, 2]
         | 
| 55 | 
            +
                  FOCAL_LEVELS: [4, 4, 4, 4]
         | 
| 56 | 
            +
                  FOCAL_WINDOWS: [3, 3, 3, 3]
         | 
| 57 | 
            +
                  DROP_PATH_RATE: 0.3
         | 
| 58 | 
            +
                  MLP_RATIO: 4.0
         | 
| 59 | 
            +
                  DROP_RATE: 0.0
         | 
| 60 | 
            +
                  PATCH_NORM: True
         | 
| 61 | 
            +
                  USE_CONV_EMBED: True
         | 
| 62 | 
            +
                  SCALING_MODULATOR: True
         | 
| 63 | 
            +
                  USE_CHECKPOINT: False
         | 
| 64 | 
            +
                  USE_POSTLN: true
         | 
| 65 | 
            +
                  USE_POSTLN_IN_MODULATION: false
         | 
| 66 | 
            +
                  USE_LAYERSCALE: True
         | 
| 67 | 
            +
                  OUT_FEATURES: ["res2", "res3", "res4", "res5"]
         | 
| 68 | 
            +
                  OUT_INDICES: [0, 1, 2, 3]
         | 
| 69 | 
            +
              ENCODER:
         | 
| 70 | 
            +
                NAME: transformer_encoder_fpn
         | 
| 71 | 
            +
                IGNORE_VALUE: 255
         | 
| 72 | 
            +
                NUM_CLASSES: 16
         | 
| 73 | 
            +
                BINARY_CLASSES: False
         | 
| 74 | 
            +
                LOSS_WEIGHT: 1.0
         | 
| 75 | 
            +
                CONVS_DIM: 512
         | 
| 76 | 
            +
                MASK_DIM: 512
         | 
| 77 | 
            +
                NORM: "GN"
         | 
| 78 | 
            +
                IN_FEATURES: ["res2", "res3", "res4", "res5"]
         | 
| 79 | 
            +
                DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
         | 
| 80 | 
            +
                COMMON_STRIDE: 4
         | 
| 81 | 
            +
                TRANSFORMER_ENC_LAYERS: 6
         | 
| 82 | 
            +
              DECODER:
         | 
| 83 | 
            +
                NAME: seem_demo
         | 
| 84 | 
            +
                TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
         | 
| 85 | 
            +
                MASK:
         | 
| 86 | 
            +
                  ENABLED: False
         | 
| 87 | 
            +
                DETECTION: False
         | 
| 88 | 
            +
                SPATIAL:
         | 
| 89 | 
            +
                  ENABLED: True
         | 
| 90 | 
            +
                  MAX_ITER: 1
         | 
| 91 | 
            +
                GROUNDING:
         | 
| 92 | 
            +
                  ENABLED: True
         | 
| 93 | 
            +
                  MAX_LEN: 5
         | 
| 94 | 
            +
                  TEXT_WEIGHT: 2.0
         | 
| 95 | 
            +
                  CLASS_WEIGHT: 0.5
         | 
| 96 | 
            +
                VISUAL:
         | 
| 97 | 
            +
                  ENABLED: False
         | 
| 98 | 
            +
                AUDIO:
         | 
| 99 | 
            +
                  ENABLED: False
         | 
| 100 | 
            +
                RETRIEVAL:
         | 
| 101 | 
            +
                  ENABLED: False
         | 
| 102 | 
            +
                LVIS:
         | 
| 103 | 
            +
                  ENABLED: True
         | 
| 104 | 
            +
                  THRES: 0.7
         | 
| 105 | 
            +
                OPENIMAGE:
         | 
| 106 | 
            +
                  ENABLED: False
         | 
| 107 | 
            +
                  NEGATIVE_SAMPLES: 5
         | 
| 108 | 
            +
                  GROUNDING:
         | 
| 109 | 
            +
                    ENABLED: False
         | 
| 110 | 
            +
                    MAX_LEN: 5
         | 
| 111 | 
            +
                CAPTION:
         | 
| 112 | 
            +
                  ENABLED: False
         | 
| 113 | 
            +
                  PHRASE_PROB: 0.5
         | 
| 114 | 
            +
                  SIM_THRES: 0.95
         | 
| 115 | 
            +
                DEEP_SUPERVISION: True
         | 
| 116 | 
            +
                NO_OBJECT_WEIGHT: 0.1
         | 
| 117 | 
            +
                GCLASS_WEIGHT: 0.4
         | 
| 118 | 
            +
                GMASK_WEIGHT: 1.0
         | 
| 119 | 
            +
                GDICE_WEIGHT: 1.0
         | 
| 120 | 
            +
                SCLASS_WEIGHT: 0.4
         | 
| 121 | 
            +
                SMASK_WEIGHT: 1.0
         | 
| 122 | 
            +
                SDICE_WEIGHT: 1.0
         | 
| 123 | 
            +
                OCLASS_WEIGHT: 0.4
         | 
| 124 | 
            +
                OMASK_WEIGHT: 1.0
         | 
| 125 | 
            +
                ODICE_WEIGHT: 1.0
         | 
| 126 | 
            +
                CLASS_WEIGHT: 2.0
         | 
| 127 | 
            +
                MASK_WEIGHT: 5.0
         | 
| 128 | 
            +
                DICE_WEIGHT: 5.0
         | 
| 129 | 
            +
                BBOX_WEIGHT: 5.0
         | 
| 130 | 
            +
                GIOU_WEIGHT: 2.0
         | 
| 131 | 
            +
                CAPTION_WEIGHT: 2.0
         | 
| 132 | 
            +
                COST_SPATIAL:
         | 
| 133 | 
            +
                  CLASS_WEIGHT: 5.0
         | 
| 134 | 
            +
                  MASK_WEIGHT: 2.0
         | 
| 135 | 
            +
                  DICE_WEIGHT: 2.0
         | 
| 136 | 
            +
                HIDDEN_DIM: 512
         | 
| 137 | 
            +
                NUM_OBJECT_QUERIES: 101
         | 
| 138 | 
            +
                NHEADS: 8
         | 
| 139 | 
            +
                DROPOUT: 0.0
         | 
| 140 | 
            +
                DIM_FEEDFORWARD: 2048
         | 
| 141 | 
            +
                MAX_SPATIAL_LEN: [512, 512, 512, 512]
         | 
| 142 | 
            +
                # ENC_LAYERS: 0
         | 
| 143 | 
            +
                PRE_NORM: False
         | 
| 144 | 
            +
                ENFORCE_INPUT_PROJ: False
         | 
| 145 | 
            +
                SIZE_DIVISIBILITY: 32
         | 
| 146 | 
            +
                TRAIN_NUM_POINTS: 12544
         | 
| 147 | 
            +
                OVERSAMPLE_RATIO: 3.0
         | 
| 148 | 
            +
                IMPORTANCE_SAMPLE_RATIO: 0.75
         | 
| 149 | 
            +
                DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
         | 
| 150 | 
            +
                TOP_GROUNDING_LAYERS: 10
         | 
| 151 | 
            +
                TOP_CAPTION_LAYERS: 10
         | 
| 152 | 
            +
                TOP_SPATIAL_LAYERS: 10
         | 
| 153 | 
            +
                TOP_OPENIMAGE_LAYERS: 10
         | 
| 154 | 
            +
                TEST:
         | 
| 155 | 
            +
                  SEMANTIC_ON: True
         | 
| 156 | 
            +
                  INSTANCE_ON: True
         | 
| 157 | 
            +
                  PANOPTIC_ON: True
         | 
| 158 | 
            +
                  OVERLAP_THRESHOLD: 0.8
         | 
| 159 | 
            +
                  OBJECT_MASK_THRESHOLD: 0.4
         | 
| 160 | 
            +
                  SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
         | 
| 161 | 
            +
                  DETECTIONS_PER_IMAGE: 100
         | 
| 162 | 
            +
             | 
| 163 | 
            +
            # Multi-modal Architecture, order matters
         | 
| 164 | 
            +
            ATTENTION_ARCH:
         | 
| 165 | 
            +
              VARIABLE:
         | 
| 166 | 
            +
                queries: ["object"]
         | 
| 167 | 
            +
                tokens: ["grounding", "spatial", "visual", "audio"]
         | 
| 168 | 
            +
              SELF_ATTENTION:
         | 
| 169 | 
            +
                queries:
         | 
| 170 | 
            +
                  object:
         | 
| 171 | 
            +
                    [
         | 
| 172 | 
            +
                      "queries_object",
         | 
| 173 | 
            +
                      "tokens_grounding",
         | 
| 174 | 
            +
                      "tokens_spatial",
         | 
| 175 | 
            +
                      "tokens_visual",
         | 
| 176 | 
            +
                      "tokens_audio",
         | 
| 177 | 
            +
                    ]
         | 
| 178 | 
            +
                tokens:
         | 
| 179 | 
            +
                  grounding: ["queries_object", "tokens_grounding"]
         | 
| 180 | 
            +
                  spatial: ["tokens_spatial"]
         | 
| 181 | 
            +
                  visual: ["tokens_visual"]
         | 
| 182 | 
            +
                  audio: ["queries_object", "tokens_audio"]
         | 
| 183 | 
            +
              CROSS_ATTENTION:
         | 
| 184 | 
            +
                queries:
         | 
| 185 | 
            +
                  object: True
         | 
| 186 | 
            +
                tokens:
         | 
| 187 | 
            +
                  grounding: False
         | 
| 188 | 
            +
                  spatial: False
         | 
| 189 | 
            +
                  visual: False
         | 
| 190 | 
            +
                  audio: False
         | 
| 191 | 
            +
              MASKING:
         | 
| 192 | 
            +
                ["tokens_spatial", "tokens_grounding", "tokens_visual", "tokens_audio"]
         | 
| 193 | 
            +
              DUPLICATION:
         | 
| 194 | 
            +
                queries:
         | 
| 195 | 
            +
                  grounding: "queries_object"
         | 
| 196 | 
            +
                  spatial: "queries_object"
         | 
| 197 | 
            +
              SPATIAL_MEMORIES: 32
         | 
| 198 | 
            +
             | 
| 199 | 
            +
            INPUT:
         | 
| 200 | 
            +
              PIXEL_MEAN: [123.675, 116.280, 103.530]
         | 
| 201 | 
            +
              PIXEL_STD: [58.395, 57.120, 57.375]
         | 
| 202 | 
            +
            # INPUT:
         | 
| 203 | 
            +
            #   PIXEL_MEAN: [64.284, 59.293, 59.962]
         | 
| 204 | 
            +
            #   PIXEL_STD: [62.484, 60.865, 59.835]
         | 
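The config above follows the SEEM/X-Decoder demo settings (focal backbone, seem_demo decoder, grounding enabled) and ends with the INPUT normalization statistics. A minimal sketch of reading it with plain PyYAML and applying PIXEL_MEAN/PIXEL_STD as standard per-channel normalization; the repo's own pipeline may load and consume the config differently:

    import numpy as np
    import yaml

    with open("configs/biomedparse_inference.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["MODEL"]["NAME"], cfg["MODEL"]["DECODER"]["NAME"])  # seem_model_demo seem_demo

    mean = np.array(cfg["INPUT"]["PIXEL_MEAN"], dtype=np.float32)
    std = np.array(cfg["INPUT"]["PIXEL_STD"], dtype=np.float32)
    rgb = np.zeros((1024, 1024, 3), dtype=np.float32)  # placeholder RGB image in [0, 255]
    normalized = (rgb - mean) / std                    # per-channel normalization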
    	
        entrypoint.sh
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
| 1 | 
            +
            #!/bin/bash
         | 
| 2 | 
            +
            if [ -f "/run/secrets/HF_TOKEN" ]; then
         | 
| 3 | 
            +
                export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)
         | 
| 4 | 
            +
            fi
         | 
| 5 | 
            +
            exec conda run --no-capture-output -n biomedparse python main.py
         | 
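entrypoint.sh only exports the token read from the Docker secret and then launches main.py inside the biomedparse conda environment. How main.py consumes HF_TOKEN is not shown in this commit; a hedged sketch of the usual pattern (names and behavior assumed, not taken from main.py):

    import os
    import huggingface_hub

    # Authenticate to the Hugging Face Hub with the token exported by entrypoint.sh,
    # so gated or private checkpoints (e.g. microsoft/BiomedParse) can be downloaded.
    token = os.environ.get("HF_TOKEN")
    if token:
        huggingface_hub.login(token=token)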
    	
        examples/Part_1_516_pathology_breast.png
    ADDED
    
    (binary PNG image, stored via Git LFS)
    	
        inference_utils/inference.py
    ADDED
    
    | @@ -0,0 +1,149 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            from torchvision import transforms
         | 
| 6 | 
            +
            #from utils.visualizer import Visualizer
         | 
| 7 | 
            +
            # from detectron2.utils.colormap import random_color
         | 
| 8 | 
            +
            # from detectron2.data import MetadataCatalog
         | 
| 9 | 
            +
            # from detectron2.structures import BitMasks
         | 
| 10 | 
            +
            from modeling.language.loss import vl_similarity
         | 
| 11 | 
            +
            from utilities.constants import BIOMED_CLASSES
         | 
| 12 | 
            +
            #from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # import cv2
         | 
| 15 | 
            +
            # import os
         | 
| 16 | 
            +
            # import glob
         | 
| 17 | 
            +
            # import subprocess
         | 
| 18 | 
            +
            from PIL import Image
         | 
| 19 | 
            +
            import random
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            t = []
         | 
| 22 | 
            +
            t.append(transforms.Resize((1024, 1024), interpolation=Image.BICUBIC))
         | 
| 23 | 
            +
            transform = transforms.Compose(t)
         | 
| 24 | 
            +
            #metadata = MetadataCatalog.get('coco_2017_train_panoptic')
         | 
| 25 | 
            +
            all_classes = ['background'] + [name.replace('-other','').replace('-merged','') 
         | 
| 26 | 
            +
                                            for name in BIOMED_CLASSES] + ["others"]
         | 
| 27 | 
            +
            # colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            # use color list from matplotlib
         | 
| 30 | 
            +
            import matplotlib.colors as mcolors
         | 
| 31 | 
            +
            colors = dict(mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS)
         | 
| 32 | 
            +
            colors_list = [list(colors.values())[i] for i in range(16)] 
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            from .output_processing import mask_stats, combine_masks
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            @torch.no_grad()
         | 
| 38 | 
            +
            def interactive_infer_image(model, image, prompts):
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                image_resize = transform(image)
         | 
| 41 | 
            +
                width = image.size[0]
         | 
| 42 | 
            +
                height = image.size[1]
         | 
| 43 | 
            +
                image_resize = np.asarray(image_resize)
         | 
| 44 | 
            +
                image = torch.from_numpy(image_resize.copy()).permute(2,0,1).cuda()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                data = {"image": image, 'text': prompts, "height": height, "width": width}
         | 
| 47 | 
            +
                
         | 
| 48 | 
            +
    # initialize task
         | 
| 49 | 
            +
                model.model.task_switch['spatial'] = False
         | 
| 50 | 
            +
                model.model.task_switch['visual'] = False
         | 
| 51 | 
            +
                model.model.task_switch['grounding'] = True
         | 
| 52 | 
            +
                model.model.task_switch['audio'] = False
         | 
| 53 | 
            +
                model.model.task_switch['grounding'] = True
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
                batch_inputs = [data]
         | 
| 57 | 
            +
                results,image_size,extra = model.model.evaluate_demo(batch_inputs)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                pred_masks = results['pred_masks'][0]
         | 
| 60 | 
            +
                v_emb = results['pred_captions'][0]
         | 
| 61 | 
            +
                t_emb = extra['grounding_class']
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 64 | 
            +
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 67 | 
            +
                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 68 | 
            +
                
         | 
| 69 | 
            +
                matched_id = out_prob.max(0)[1]
         | 
| 70 | 
            +
                pred_masks_pos = pred_masks[matched_id,:,:]
         | 
| 71 | 
            +
                pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
         | 
| 72 | 
            +
             | 
| 73 | 
            +
    # interpolate mask back to the original image size
         | 
| 74 | 
            +
                pred_mask_prob = F.interpolate(pred_masks_pos[None,], (data['height'], data['width']), 
         | 
| 75 | 
            +
                                               mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
         | 
| 76 | 
            +
                pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)
         | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                return pred_mask_prob
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            # def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None):
         | 
| 83 | 
            +
            #     image_ori = transform(image)
         | 
| 84 | 
            +
            #     #mask_ori = image['mask']
         | 
| 85 | 
            +
            #     width = image_ori.size[0]
         | 
| 86 | 
            +
            #     height = image_ori.size[1]
         | 
| 87 | 
            +
            #     image_ori = np.asarray(image_ori)
         | 
| 88 | 
            +
            #     visual = Visualizer(image_ori, metadata=metadata)
         | 
| 89 | 
            +
            #     images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            #     data = {"image": images, "height": height, "width": width}
         | 
| 92 | 
            +
            #     if len(tasks) == 0:
         | 
| 93 | 
            +
            #         tasks = ["Panoptic"]
         | 
| 94 | 
            +
                
         | 
| 95 | 
            +
            #     # inistalize task
         | 
| 96 | 
            +
            #     model.model.task_switch['spatial'] = False
         | 
| 97 | 
            +
            #     model.model.task_switch['visual'] = False
         | 
| 98 | 
            +
            #     model.model.task_switch['grounding'] = False
         | 
| 99 | 
            +
            #     model.model.task_switch['audio'] = False
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            #     # check if reftxt is list of strings
         | 
| 102 | 
            +
            #     assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}"
         | 
| 103 | 
            +
            #     model.model.task_switch['grounding'] = True
         | 
| 104 | 
            +
            #     predicts = {}
         | 
| 105 | 
            +
            #     for i, txt in enumerate(reftxt): 
         | 
| 106 | 
            +
            #         data['text'] = txt
         | 
| 107 | 
            +
            #         batch_inputs = [data]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            #         results,image_size,extra = model.model.evaluate_demo(batch_inputs)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            #         pred_masks = results['pred_masks'][0]
         | 
| 112 | 
            +
            #         v_emb = results['pred_captions'][0]
         | 
| 113 | 
            +
            #         t_emb = extra['grounding_class']
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 116 | 
            +
            #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            #         temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 119 | 
            +
            #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 120 | 
            +
                    
         | 
| 121 | 
            +
            #         matched_id = out_prob.max(0)[1]
         | 
| 122 | 
            +
            #         pred_masks_pos = pred_masks[matched_id,:,:]
         | 
| 123 | 
            +
            #         pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
         | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
            #         # interpolate mask to ori size
         | 
| 127 | 
            +
            #         #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
         | 
| 128 | 
            +
            #         # masks.append(pred_masks_pos[0])
         | 
| 129 | 
            +
            #         # mask = pred_masks_pos[0]
         | 
| 130 | 
            +
            #         # masks.append(mask)
         | 
| 131 | 
            +
            #         # interpolate mask to ori size
         | 
| 132 | 
            +
            #         pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
         | 
| 133 | 
            +
            #         #pred_masks_pos = 1*(pred_mask_prob > 0.5)
         | 
| 134 | 
            +
            #         predicts[txt] = pred_mask_prob[0]
         | 
| 135 | 
            +
                    
         | 
| 136 | 
            +
            #     masks = combine_masks(predicts)
         | 
| 137 | 
            +
                    
         | 
| 138 | 
            +
            #     predict_mask_stats = {}
         | 
| 139 | 
            +
            #     print(masks.keys())
         | 
| 140 | 
            +
            #     for i, txt in enumerate(masks):
         | 
| 141 | 
            +
            #         mask = masks[txt]
         | 
| 142 | 
            +
            #         demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt)
         | 
| 143 | 
            +
            #         predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori)
         | 
| 144 | 
            +
                    
         | 
| 145 | 
            +
            #     res = demo.get_image()
         | 
| 146 | 
            +
            #     torch.cuda.empty_cache()
         | 
| 147 | 
            +
            #     # return Image.fromarray(res), stroke_inimg, stroke_refimg
         | 
| 148 | 
            +
            #     return Image.fromarray(res), None, predict_mask_stats
         | 
| 149 | 
            +
             | 
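A hedged usage sketch for interactive_infer_image, assuming a BiomedParse/SEEM model wrapper `model` has already been built on GPU elsewhere (e.g. in main.py; model construction is not reproduced here). The image path is the example added in this commit, the prompt string is hypothetical, and the 0.5 threshold mirrors the one used inside the function:

    import numpy as np
    from PIL import Image
    from inference_utils.inference import interactive_infer_image

    image = Image.open("examples/Part_1_516_pathology_breast.png").convert("RGB")
    prompts = ["neoplastic cells"]  # hypothetical text prompt(s)

    # `model` is assumed to be an already-initialized model wrapper on CUDA.
    prob_maps = interactive_infer_image(model, image, prompts)  # one HxW probability map per prompt
    binary_masks = (prob_maps > 0.5).astype(np.uint8)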
    	
        inference_utils/output_processing.py
    ADDED
    
    | @@ -0,0 +1,91 @@ | |
import json
from scipy import stats
import numpy as np

import huggingface_hub


def check_mask_stats(img, mask, modality_type, target):
    # img: np.array, shape=(H, W, 3) RGB image with pixel values in [0, 255]
    # mask: np.array, shape=(H, W, 1) mask probability scaled to [0,255] with pixel values in [0, 255]
    # modality_type: str, see target_dist.json for the list of modality types
    # target: str, see target_dist.json for the list of targets

    huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename='target_dist.json', local_dir='./inference_utils')
    huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename="config.yaml", local_dir="./configs")
    target_dist = json.load(open("inference_utils/target_dist.json"))

    if modality_type not in target_dist:
        raise ValueError(f"Currently support modality types: {list(target_dist.keys())}")

    if target not in target_dist[modality_type]:
        raise ValueError(f"Currently support targets for {modality_type}: {list(target_dist[modality_type].keys())}")

    ms = mask_stats(mask, img)

    ps = [stats.ks_1samp([ms[i]], stats.beta(param[0], param[1]).cdf).pvalue for i, param in enumerate(target_dist[modality_type][target])]
    p_value = np.prod(ps)

    adj_p_value = p_value**0.24    # adjustment for four test products

    return adj_p_value


def mask_stats(mask, img):
    # mask is a prediction mask with pixel values in [0, 255] for probability in [0, 1]
    # img is a RGB image with pixel values in [0, 255]
    if mask.max() <= 127:
        return [0, 0, 0, 0]
    return [mask[mask>=128].mean()/256, img[:,:,0][mask>=128].mean()/256,
            img[:,:,1][mask>=128].mean()/256, img[:,:,2][mask>=128].mean()/256]


def combine_masks(predicts):
    # predicts: a dictionary of pixel probability, {TARGET: pred_prob}
    pixel_preds = {}
    target_area = {}
    target_probs = {}
    for target in predicts:
        pred = predicts[target]
        pred_region = np.where(pred > 0.1)
        target_area[target] = 0
        target_probs[target] = 0
        for (i,j) in zip(*pred_region):
            if (i,j) not in pixel_preds:
                pixel_preds[(i,j)] = {}
            pixel_preds[(i,j)][target] = pred[i,j]
            target_area[target] += 1
            target_probs[target] += pred[i,j]
    for target in predicts:
        if target_area[target] == 0:
            continue
        target_probs[target] /= target_area[target]

    # generate combined masks
    combined_areas = {t: 0 for t in predicts}
    for index in pixel_preds:
        pred_target = sorted(pixel_preds[index].keys(), key=lambda t: pixel_preds[index][t], reverse=True)[0]
        combined_areas[pred_target] += 1

    # discard targets with small areas
    discard_targets = []
    for target in predicts:
        if combined_areas[target] < 0.6 * target_area[target]:
            discard_targets.append(target)

    # keep the most confident target
    most_confident_target = sorted(predicts.keys(), key=lambda t: target_probs[t], reverse=True)[0]

    discard_targets = [t for t in discard_targets if t != most_confident_target]

    masks = {t: np.zeros_like(predicts[t]).astype(np.uint8) for t in predicts if t not in discard_targets}
    for index in pixel_preds:
        candidates = [t for t in pixel_preds[index] if t not in discard_targets and pixel_preds[index][t] > 0.5]
        if len(candidates) == 0:
            continue
        pred_target = max(candidates, key=lambda t: pixel_preds[index][t])
        masks[pred_target][index[0], index[1]] = 1

    return masks
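A minimal usage sketch of the two helpers above, for orientation. The image, probability maps, and the Pathology/neoplastic-cells target are illustrative only: combine_masks expects per-prompt probability maps in [0, 1], while check_mask_stats expects the mask rescaled to [0, 255] and downloads target_dist.json from the microsoft/BiomedParse repo on first use.

import numpy as np
from inference_utils.output_processing import check_mask_stats, combine_masks

# toy inputs: an RGB image and two overlapping per-prompt probability maps
image = np.full((256, 256, 3), 180, dtype=np.uint8)
pred_a = np.zeros((256, 256), dtype=np.float32)
pred_a[60:120, 60:120] = 0.9
pred_b = np.zeros((256, 256), dtype=np.float32)
pred_b[100:160, 100:160] = 0.7
predicts = {"neoplastic cells": pred_a, "inflammatory cells": pred_b}

# resolve the overlap into non-overlapping binary masks, one per surviving prompt
masks = combine_masks(predicts)

# score one prediction against the reference statistics for its modality and target
adj_p = check_mask_stats(image, pred_a * 255, "Pathology", "neoplastic cells")
print(list(masks.keys()), adj_p)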
    	
inference_utils/processing_utils.py ADDED
@@ -0,0 +1,182 @@
import numpy as np
from skimage import transform
import pydicom
from io import BytesIO
from PIL import Image
import nibabel as nib
import SimpleITK as sitk
from skimage import measure


"""
    This script contains utility functions for reading and processing different imaging modalities.
"""


CT_WINDOWS = {'abdomen': [-150, 250],
              'lung': [-1000, 1000],
              'pelvis': [-55, 200],
              'liver': [-25, 230],
              'colon': [-68, 187],
              'pancreas': [-100, 200]}

def process_intensity_image(image_data, is_CT, site=None):
    # process intensity-based image. If CT, apply site specific windowing

    # image_data: 2D numpy array of shape (H, W)

    # return: 3-channel numpy array of shape (H, W, 3) as model input

    if is_CT:
        # process image with windowing
        if site and site in CT_WINDOWS:
            window = CT_WINDOWS[site]
        else:
            raise ValueError(f'Please choose CT site from {CT_WINDOWS.keys()}')
        lower_bound, upper_bound = window
    else:
        # process image with intensity range 0.5-99.5 percentile
        lower_bound, upper_bound = np.percentile(
            image_data[image_data > 0], 0.5
        ), np.percentile(image_data[image_data > 0], 99.5)

    image_data_pre = np.clip(image_data, lower_bound, upper_bound)
    image_data_pre = (
        (image_data_pre - image_data_pre.min())
        / (image_data_pre.max() - image_data_pre.min())
        * 255.0
    )

    # pad to square with equal padding on both sides
    shape = image_data_pre.shape
    if shape[0] > shape[1]:
        pad = (shape[0]-shape[1])//2
        pad_width = ((0,0), (pad, pad))
    elif shape[0] < shape[1]:
        pad = (shape[1]-shape[0])//2
        pad_width = ((pad, pad), (0,0))
    else:
        pad_width = None

    if pad_width is not None:
        image_data_pre = np.pad(image_data_pre, pad_width, 'constant', constant_values=0)

    # resize image to 1024x1024
    image_size = 1024
    resize_image = transform.resize(image_data_pre, (image_size, image_size), order=3,
                                    mode='constant', preserve_range=True, anti_aliasing=True)

    # convert to 3-channel image
    resize_image = np.stack([resize_image]*3, axis=-1)

    return resize_image.astype(np.uint8)


def read_dicom(image_path, is_CT, site=None):
    # read dicom file and return pixel data

    # image_path: str, path to dicom file
    # is_CT: bool, whether image is CT or not
    # site: str, one of CT_WINDOWS.keys()
    # return: 3-channel numpy array of shape (1024, 1024, 3)

    ds = pydicom.dcmread(image_path)
    image_array = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept

    image_array = process_intensity_image(image_array, is_CT, site)

    return image_array


def read_nifti(image_path, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None):
    # read nifti file and return pixel data

    # image_path: str, path to nifti file
    # is_CT: bool, whether image is CT or not
    # slice_idx: int, slice index to read
    # site: str, one of CT_WINDOWS.keys()
    # HW_index: tuple, index of height and width in the image shape
    # return: 3-channel numpy array of shape (1024, 1024, 3)

    nii = nib.load(image_path)
    image_array = nii.get_fdata()

    if HW_index != (0, 1):
        image_array = np.moveaxis(image_array, HW_index, (0, 1))

    # get slice
    if channel_idx is None:
        image_array = image_array[:, :, slice_idx]
    else:
        image_array = image_array[:, :, slice_idx, channel_idx]

    image_array = process_intensity_image(image_array, is_CT, site)
    return image_array


def read_rgb(image_path):
    # read RGB image and return resized pixel data

    # image_path: str, path to RGB image
    # return: 3-channel numpy array of shape (1024, 1024, 3)

    # read image into numpy array
    image = Image.open(image_path)
    image = np.array(image)
    if len(image.shape) == 2:
        image = np.stack([image]*3, axis=-1)
    elif image.shape[2] == 4:
        image = image[:,:,:3]

    # pad to square with equal padding on both sides
    shape = image.shape
    if shape[0] > shape[1]:
        pad = (shape[0]-shape[1])//2
        pad_width = ((0,0), (pad, pad), (0,0))
    elif shape[0] < shape[1]:
        pad = (shape[1]-shape[0])//2
        pad_width = ((pad, pad), (0,0), (0,0))
    else:
        pad_width = None

    if pad_width is not None:
        image = np.pad(image, pad_width, 'constant', constant_values=0)

    # resize image to 1024x1024 for each channel
    image_size = 1024
    resize_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
    for i in range(3):
        resize_image[:,:,i] = transform.resize(image[:,:,i], (image_size, image_size), order=3,
                                    mode='constant', preserve_range=True, anti_aliasing=True)

    return resize_image


def get_instances(mask):
    # get instances from binary mask via watershed on the signed distance map
    seg = sitk.GetImageFromArray(mask)
    filled = sitk.BinaryFillhole(seg)
    d = sitk.SignedMaurerDistanceMap(filled, insideIsPositive=False, squaredDistance=False, useImageSpacing=False)

    ws = sitk.MorphologicalWatershed(d, markWatershedLine=False, level=1)
    ws = sitk.Mask(ws, sitk.Cast(seg, ws.GetPixelID()))
    ins_mask = sitk.GetArrayFromImage(ws)

    # filter out instances with small area outliers
    props = measure.regionprops_table(ins_mask, properties=('label', 'area'))
    mean_area = np.mean(props['area'])
    std_area = np.std(props['area'])

    threshold = mean_area - 2*std_area - 1
    ins_mask_filtered = ins_mask.copy()
    for i, area in zip(props['label'], props['area']):
        if area < threshold:
            ins_mask_filtered[ins_mask == i] = 0

    return ins_mask_filtered
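A small, hypothetical usage sketch of these readers. The .dcm and .nii.gz paths are placeholders, site must be one of the CT_WINDOWS keys when is_CT=True, and every reader returns a 1024x1024x3 uint8 array ready for the model.

from inference_utils.processing_utils import read_dicom, read_nifti, read_rgb

# CT DICOM slice with abdomen windowing
ct_img = read_dicom("example_ct_slice.dcm", is_CT=True, site="abdomen")

# one slice of an MRI NIfTI volume (percentile normalization instead of windowing)
mri_img = read_nifti("example_mri.nii.gz", is_CT=False, slice_idx=60)

# plain RGB image (pathology, endoscopy, fundus, ...) from the bundled example
rgb_img = read_rgb("examples/Part_1_516_pathology_breast.png")

print(ct_img.shape, mri_img.shape, rgb_img.shape)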
    	
inference_utils/target_dist.json ADDED
@@ -0,0 +1 @@
            {"CT-Abdomen": {"postcava": [[244.8001455798728, 5.314270814858824], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391]], "aorta": [[570.5260544851909, 8.97527503179567], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238]], "right kidney": [[831.8568013426873, 14.991866448573818], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316]], "kidney": [[824.7288483151449, 17.740666994112335], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919]], "left kidney": [[765.9269280548916, 14.314482540419498], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515]], "duodenum": [[121.5002253116006, 5.0616837393558045], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173]], "pancreas": [[182.85416969377923, 6.9039775525067135], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656]], "liver (non abdomen window)": [[481.5690096331249, 8.413924027868077], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198]], "liver": [[497.88613290346797, 8.79208581405346], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742]], "spleen": [[496.77984794364835, 8.498216025126785], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987]], "stomach": [[137.7555592980079, 3.928159238756134], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921]], "gallbladder": [[109.56988864543307, 3.4765854683723596], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384]], "left adrenal gland": [[121.60075395406241, 4.266683492995461], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753]], "adrenal gland": [[182.4265613513338, 7.813186080282246], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345]], "right adrenal gland": [[158.21570288963346, 5.736947411814261], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653]], "bladder": [[172.667607742299, 4.6885066612866835], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909]], "esophagus": [[253.86092392814248, 6.886078359154348], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301]]}, "CT-Chest": {"nodule": [[115.14726334918862, 3.0043952160348844], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393]], "COVID-19 infection": [[226.93782607812352, 10.662200522447263], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407]], "tumor": [[81.39154648592063, 3.0363381821985254], [9.799683628807484, 19.248706134279548], 
[9.799683628807484, 19.248706134279548], [9.799683628807484, 19.248706134279548]]}, "MRI-Abdomen": {"aorta": [[840.9822169946456, 13.699556855062456], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954]], "postcava": [[151.3891903352374, 4.700455115571472], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995]], "right kidney": [[613.4017011464975, 11.282616103318485], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867]], "duodenum": [[88.51851857758399, 5.251374959142798], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745]], "kidney": [[831.5762248415444, 18.739059302777875], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527]], "left kidney": [[255.4744196400276, 5.573793361388763], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708]], "liver": [[491.1931789168259, 9.294627086787225], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463]], "pancreas": [[136.2304629992425, 5.676744286342953], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567]], "gallbladder": [[75.18767252055355, 2.8711737605829892], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496]], "stomach": [[89.16380420023327, 4.461224829090838], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376]], "spleen": [[413.92566589639046, 7.99961594912814], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216]], "left adrenal gland": [[86.44109991236728, 4.826813402237061], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408]], "adrenal gland": [[303.9642820935704, 16.729857009916806], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544]], "right adrenal gland": [[172.36803145644578, 8.050377438528958], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772]], "esophagus": [[193.1348898340059, 7.6397334220243325], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354]]}, "MRI-Cardiac": {"left heart ventricle": [[964.9072936969454, 17.21177762137991], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713]], "myocardium": [[448.3393673888417, 17.591805257426998], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415]], "right heart ventricle": [[359.88937669636215, 9.392153523781843], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979]]}, "MRI-FLAIR-Brain": {"edema": [[69.4159007224176, 5.568921766085619], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592]], "tumor 
core": [[154.26935124167449, 8.089254912853598], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397]], "whole tumor": [[485.48717118600956, 16.01178236475156], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145]]}, "MRI-T1-Gd-Brain": {"enhancing tumor": [[175.6437881777937, 7.539344668413025], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689]], "non-enhancing tumor": [[37.6625733247702, 3.8454536110058246], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484]], "tumor core": [[180.88223552813486, 6.610443841067055], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197]]}, "Pathology": {"connective tissue cells": [[46.71165884847293, 4.997126203483956], [9.942495884846476, 15.700775443760845], [4.328453739888501, 18.42621798468577], [9.798096322131162, 11.920352021312304]], "inflammatory cells": [[39.600337990197595, 3.1848025413959706], [6.287418328538852, 20.538379638162322], [2.9521703595392146, 25.264465092284006], [6.559595490616054, 12.004686961917436]], "neoplastic cells": [[82.29374052289526, 8.22429924322936], [9.592296798563375, 14.818916788142138], [4.948629785308088, 19.78516221506478], [10.729094314024243, 12.934345198477494]], "epithelial cells": [[91.75183574899573, 9.577544361042948], [13.469843493323452, 27.305962287612964], [4.696928248406198, 25.254143364646463], [11.077634907582583, 13.487595094752443]]}, "X-Ray-Chest": {"left lung": [[529.1669758355144, 7.465035502868491], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364]], "lung": [[465.7809501354513, 7.147122106450173], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102]], "right lung": [[567.6127039725319, 7.532428563004494], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746]]}, "Ultrasound-Cardiac": {"left heart atrium": [[1188.687550702627, 24.234766943758856], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291]], "left heart ventricle": [[2787.334986695437, 58.297232816307506], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377]]}, "Endoscopy": {"neoplastic polyp": [[392.89875472390315, 5.4678888279040745], [7.477729277754545, 1.6522601344780465], [7.2704247484339035, 6.347521355120636], [4.3902399436060335, 6.543658310376327]], "polyp": [[163.7838288028474, 3.4851615302599117], [7.03659746479883, 1.9088902542177986], [6.992807172875011, 6.756628353721484], [5.185761648208865, 8.977427344868255]], "non-neoplastic polyp": [[214.9199548332033, 4.360826895414348], [7.303363948417486, 1.9789835935004905], [10.54652900087687, 9.009706115553772], [6.917879576439251, 10.404634951284532]]}, "Fundus": {"optic cup": [[1482.9561484784422, 35.78105120937013], [52.1031548324398, 1.5080077510381715], [10.023538467761934, 3.1641925551155046], [3.394564722036805, 2.4391933423559626]], "optic disc": [[626.9141229495486, 20.95002931507066], [18.278454005466408, 1.8261365514325893], [16.42282430959315, 11.171338052048034], [4.8937792939550135, 6.987302868644637]]}, 
"Dermoscopy": {"lesion": [[134.43456931870887, 4.743684855379663], [5.18053578956456, 2.3527492367343634], [3.809383004477107, 6.368793378843402], [2.3888068456218847, 6.655396307215968]], "melanoma": [[454.17848530764076, 9.6466178116726], [4.022144360826467, 7.870140640677671], [4.87109613458874, 18.93721534855073], [3.107895746664011, 13.604075970992069]]}, "OCT": {"edema": [[260.11475018501574, 7.379315940573871], [4.162158474003, 17.437425953761988], [12.65808078622105, 81.37165793634547], [1.763378481483125, 4.427309203795247]]}}
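The JSON above maps each modality to its supported targets, and each target stores four [alpha, beta] Beta-distribution parameter pairs, one for each statistic returned by mask_stats (mean mask probability plus the mean R, G, B intensity inside the mask). A sketch of how a single statistic is scored, mirroring the ks_1samp call in check_mask_stats; the observed value here is made up:

import json
from scipy import stats

target_dist = json.load(open("inference_utils/target_dist.json"))
alpha, beta = target_dist["Pathology"]["neoplastic cells"][0]   # parameters for the first statistic

observed = 0.35   # hypothetical mask_stats()[0] value
p = stats.ks_1samp([observed], stats.beta(alpha, beta).cdf).pvalue
print(p)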
    	
main.py ADDED
@@ -0,0 +1,106 @@
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image


def overlay_masks(image, masks, colors):
    overlay = image.copy()
    overlay = np.array(overlay, dtype=np.uint8)
    for mask, color in zip(masks, colors):
        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
            np.uint8
        )
    return Image.fromarray(overlay)


def generate_colors(n):
    cmap = plt.get_cmap("tab10")
    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
    return colors


def init_model():
    # Download model
    model_file = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Initialize model
    conf_files = "configs/biomedparse_inference.yaml"
    opt = load_opt_from_config_files([conf_files])
    opt = init_distributed(opt)

    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )

    return model


def predict(image, prompts):
    if not prompts:
        return None

    # Convert string input to list
    prompts = [p.strip() for p in prompts.split(",")]

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get predictions
    pred_mask = interactive_infer_image(model, image, prompts)

    # Generate visualization
    colors = generate_colors(len(prompts))
    pred_overlay = overlay_masks(
        image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
    )

    return pred_overlay


def run():
    global model
    model = init_model()

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(
                label="Prompts",
                placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
            ),
        ],
        outputs=gr.Image(type="pil", label="Prediction"),
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=[
            [
                "examples/Part_1_516_pathology_breast.png",
                "neoplastic cells, inflammatory cells",
            ]
        ],
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Log only whether the token is configured; avoid printing the secret itself.
    print(f"HF_TOKEN set: {os.getenv('HF_TOKEN') is not None}")
    run()
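For debugging outside the Gradio UI, the same pipeline can be driven headlessly. This is a sketch only; it assumes a CUDA device and an HF_TOKEN with access to the microsoft/BiomedParse weights:

from PIL import Image
import main

main.model = main.init_model()   # download the checkpoint and build the model once
img = Image.open("examples/Part_1_516_pathology_breast.png")
overlay = main.predict(img, "neoplastic cells, inflammatory cells")
overlay.save("prediction_overlay.png")   # the same overlay the web UI displays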
    	
modeling/BaseModel.py ADDED
@@ -0,0 +1,45 @@
import os
import logging

import torch
import torch.nn as nn

from utilities.model import align_and_update_state_dicts

from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files

import huggingface_hub

logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
    def __init__(self, opt, module: nn.Module):
        super(BaseModel, self).__init__()
        self.opt = opt
        self.model = module

    def forward(self, *inputs, **kwargs):
        outputs = self.model(*inputs, **kwargs)
        return outputs

    def save_pretrained(self, save_dir):
        torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))

    def from_pretrained(self, pretrained, filename: str = "biomedparse_v1.pt",
                        local_dir: str = "./pretrained", config_dir: str = "./configs"):
        if pretrained.startswith("hf_hub:"):
            hub_name = pretrained.split(":")[1]
            huggingface_hub.hf_hub_download(hub_name, filename=filename,
                                            local_dir=local_dir)
            huggingface_hub.hf_hub_download(hub_name, filename="config.yaml",
                                            local_dir=config_dir)
            load_dir = os.path.join(local_dir, filename)
        else:
            load_dir = pretrained

        state_dict = torch.load(load_dir, map_location=self.opt['device'])
        state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
        self.model.load_state_dict(state_dict, strict=False)
        return self
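from_pretrained accepts either a local checkpoint path (as main.py passes) or an "hf_hub:<repo_id>" reference, in which case it first downloads the weights and config.yaml. A minimal sketch of the hub form, assuming the same inference config and CUDA device used elsewhere in this Space:

from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.arguments import load_opt_from_config_files
from utilities.distributed import init_distributed

opt = init_distributed(load_opt_from_config_files(["configs/biomedparse_inference.yaml"]))
model = BaseModel(opt, build_model(opt)).from_pretrained("hf_hub:microsoft/BiomedParse").eval().cuda()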
    	
modeling/__init__.py ADDED
@@ -0,0 +1 @@
from .architectures import build_model
    	
modeling/architectures/__init__.py ADDED
@@ -0,0 +1,5 @@
from .xdecoder_model import *
from .seem_model_v0 import *
from .seem_model_v1 import *
from .seem_model_demo import *
from .build import build_model
    	
modeling/architectures/build.py ADDED
@@ -0,0 +1,22 @@
_model_entrypoints = {}


def build_model(config, **kwargs):
    model_name = config['MODEL']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)

def register_model(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
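This is a plain name-based registry: the wildcard imports in modeling/architectures/__init__.py execute each architecture module, whose @register_model factory is stored under the defining module's filename, so the config's MODEL.NAME must equal one of those filenames. A small sketch of how the registry can be inspected and used; the NAME this Space actually selects comes from configs/biomedparse_inference.yaml:

from modeling.architectures import build_model           # importing the package registers the architectures
from modeling.architectures.build import _model_entrypoints

print(sorted(_model_entrypoints.keys()))                  # registered module names, e.g. the seem_model_* variants
# build_model(opt) looks up opt['MODEL']['NAME'] in this dict and calls the matching
# factory with the full loaded config, which is how main.py instantiates the model.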
    	
modeling/architectures/seem_model_demo.py ADDED
@@ -0,0 +1,923 @@
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import random
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from kornia.contrib import distance_transform

from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog

from .build import register_model

from ..utils import configurable, get_class_names, get_iou
from ..vision.backbone import build_backbone, Backbone
from ..body import build_xdecoder_head
from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
from ..language import build_language_encoder
from ..language.loss import vl_similarity
from utilities.prompt_engineering import prompt_engineering
from utilities.constants import COCO_PANOPTIC_CLASSES

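# A minimal usage sketch (an illustrative assumption, not part of the original
# file): `cfg` is a parsed config dict containing the MODEL/INPUT keys read in
# `from_config` below, and `image_tensor` is an RGB tensor in (C, H, W) layout.
#
#   model = GeneralizedSEEM(cfg)   # @configurable routes `cfg` through from_config()
#   model.eval()
#   batched_inputs = [{"image": image_tensor, "height": H, "width": W}]
#   with torch.no_grad():
#       results = model(batched_inputs)   # mode='default' -> evaluate()
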
class GeneralizedSEEM(nn.Module):

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        losses: dict,
        num_queries: int,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        task_switch: dict,
        phrase_prob: float,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        # inference
        semantic_on: bool,
        panoptic_on: bool,
        instance_on: bool,
        test_topk_per_image: int,
        train_dataset_name: str,
        interactive_mode: str,
        interactive_iter: int,
        dilation_kernel: torch.Tensor,
    ):
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.losses = losses
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on

        # caption argument
        self.task_switch = task_switch
        self.phrase_prob = phrase_prob

        self.test_topk_per_image = test_topk_per_image
        self.train_class_names = None
        self.interactive_mode = interactive_mode
        self.interactive_iter = interactive_iter

        if not self.semantic_on:
            assert self.sem_seg_postprocess_before_inference

        self.register_buffer("dilation_kernel", dilation_kernel)

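    # `from_config` maps a nested config dict onto the constructor arguments:
    # it builds the vision backbone, the language encoder, and the X-Decoder
    # head, and fills in the task switches. Training-only fields (criterion,
    # losses, loss weights) are deliberately left empty or None in this
    # inference-oriented variant.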
    @classmethod
    def from_config(cls, cfg):
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
                            'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}

        task_switch = {'bbox': dec_cfg.get('DETECTION', False),
                       'mask': dec_cfg.get('MASK', True),
                       'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
                       'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
                       'openimage': openimage_switch,
                       'visual': dec_cfg['VISUAL'].get('ENABLED', False),
                       'audio': dec_cfg['AUDIO'].get('ENABLED', False)}

        # build model
        extra = {'task_switch': task_switch}
        backbone = build_backbone(cfg)
        lang_encoder = build_language_encoder(cfg)
        sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)

        # Training Settings.
        loss_weights = {}
        matcher = None
        losses = {}
        weight_dict = {}
        grd_weight = {}
        top_x_layers = {}
        criterion = None
        train_dataset_name = None
        phrase_prob = None
        # Loss parameters:
        deep_supervision = None
        no_object_weight = None

        interactive_mode = 'best'
        interactive_iter = 20
        dilation = 3
        dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "losses": losses,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": None,
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                or dec_cfg['TEST']['PANOPTIC_ON']
                or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            "task_switch": task_switch,
            "phrase_prob": phrase_prob,
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
            "train_dataset_name": train_dataset_name,
            "interactive_mode": interactive_mode,
            "interactive_iter": interactive_iter,
            "dilation_kernel": dilation_kernel,
        }

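    # `forward` dispatches on `self.training` and `mode`: training accumulates
    # per-task losses and scales them by `criterion.weight_dict`, while
    # evaluation routes to the interactive, spatial-grounding, text-grounding,
    # or default segmentation paths below.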
    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs, mode='default'):
        if self.training:
            losses = {}
            if self.task_switch['mask']:
                losses_seg = self.forward_seg(batched_inputs)
                losses.update(losses_seg)
            if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
                losses_openimage = self.forward_openimage(batched_inputs['openimage'])
                losses_openimage = {key.replace('mask', 'openimage'): value for key, value in losses_openimage.items()}
                losses_openimage = {key.replace('grounding', 'grounding_openimage'): value for key, value in losses_openimage.items()}
                losses.update(losses_openimage)
            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:  # remove this loss if not specified in `weight_dict`
                    losses.pop(k)
            return losses
        else:
            if mode == 'interactive':
                return self.evaluate_interactive(batched_inputs)
            elif mode == 'grounding_spatial':
                return self.evaluate_grounding_sptial(batched_inputs, mode)
            elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
                return self.evaluate_grounding(batched_inputs, mode)
            else:
                return self.evaluate(batched_inputs)

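    # Training path: images are normalized and padded to `size_divisibility`,
    # grounding/spatial targets are packed into `extra`, and a few interactive
    # rounds may be simulated without gradients before the supervised forward
    # pass that feeds the bipartite matching-based criterion.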
    def forward_seg(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)

        extra = {}
        # mask classification target
        if "instances" in batched_inputs[0]:
            # input bounding box is checked to be correct.
            targets = self.prepare_targets(batched_inputs, images)

            if self.task_switch['grounding']:
                grounding_tokens = [x['grounding_query_embs'] for x in targets]  # need to pad for more than one grounding token
                grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
                non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
                grounding_tokens[non_zero_query_mask] = 0

                extra['grounding_tokens'] = grounding_tokens
                extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            if self.task_switch['spatial']:
                pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
                neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
                fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
                extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})

        features = self.backbone(images.tensor)
        mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        # forward spatial only without gradient
        if self.task_switch['spatial']:
            with torch.no_grad():
                # simulate a random number of interaction rounds (0-2) before the supervised step
                rand_iter_num = random.randint(0, 2)
                for i in range(rand_iter_num):
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
                    extra.update(outputs)
                    extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
        extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
                 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
                 'false_positive_mask': extra['false_positive_mask']}
        # bipartite matching-based loss
        self.criterion.losses = self.losses['seg']  # seg criterion losses
        losses = self.criterion(outputs, targets, extra)

        del outputs
        return losses

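    # Demo path: prompts arrive pre-packed in `batched_inputs[0]` and are
    # translated into the `extra` dict consumed by the predictor
    # ('stroke' -> spatial query masks, 'text'/'audio' -> grounding tokens,
    # 'visual' -> passed through unchanged).
    #
    # Hedged input sketch (shapes and prompt text are illustrative assumptions):
    #   batched_inputs = [{
    #       "image": torch.zeros(3, 1024, 1024),                      # un-normalized RGB tensor
    #       "text": ["neoplastic cells"],                             # free-text prompt(s)
    #       "stroke": torch.zeros(1, 1024, 1024, dtype=torch.bool),   # optional scribble mask
    #   }]
    #   outputs, padded_shape, extra = model.evaluate_demo(batched_inputs)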
    def evaluate_demo(self, batched_inputs):
        assert len(batched_inputs) == 1, "only support batch size equal to 1"
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]

        extra = {}
        if 'stroke' in batched_inputs[0]:
            pos_masks = (batched_inputs[0]['stroke'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
            neg_masks = (batched_inputs[0]['stroke'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        if 'visual' in batched_inputs[0]:
            extra.update(batched_inputs[0]['visual'])

        if 'text' in batched_inputs[0]:
            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['text'], name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:, None].shape[:-1], dtype=torch.bool, device=query_emb.device)
            extra['grounding_tokens'] = query_emb[:, None]
            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
            extra['grounding_class'] = gtext['class_emb']

        if 'audio' in batched_inputs[0]:
            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['audio'], name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:, None].shape[:-1], dtype=torch.bool, device=query_emb.device)
            extra['audio_tokens'] = query_emb[:, None]
            extra['audio_nonzero_mask'] = non_zero_query_mask.t()
            extra['audio_class'] = gtext['class_emb']

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='demo')
        return outputs, images.tensor.shape, extra

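    # Default evaluation: a single forward through the segmentation head,
    # masks upsampled to the padded input size, then per-image postprocessing
    # into "sem_seg", "panoptic_seg", and "instances" results depending on the
    # *_on flags from the config.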
    def evaluate(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]

        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        input_size = mask_pred_results.shape[-2:]
        del outputs

        processed_results = []
        for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width
                )
                mask_cls_result = mask_cls_result.to(mask_pred_result)

            # semantic segmentation inference
            if self.semantic_on:
                r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
                if not self.sem_seg_postprocess_before_inference:
                    r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                processed_results[-1]["sem_seg"] = r

            # panoptic segmentation inference
            if self.panoptic_on:
                panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                processed_results[-1]["panoptic_seg"] = panoptic_r

            # instance segmentation inference
            if self.instance_on:
                if self.task_switch['bbox']:
                    box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
                instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
                processed_results[-1]["instances"] = instance_r

        return processed_results

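    # Interactive evaluation: starting from the provided click/scribble masks,
    # the predictor is run for `interactive_iter` rounds; after every round the
    # mask IoU against `gt_masks_orisize` is recorded and the next simulated
    # click is produced by `prepare_next_spaital_mask`.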
    def evaluate_interactive(self, batched_inputs):
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        extra = {}

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm, 1, 1, 1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm, 1, 1, 1)

        all_batch_shape_iou = []
        pred_smask_pointer = None
        prev_smask_pointer = None
        pred_smask_all = None

        query_index = self.sem_seg_head.predictor.query_index
        assert self.interactive_mode == 'best'
        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
        extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        for i in range(self.interactive_iter):
            outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
            extra.update(outputs)
            pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')

            s = image_sizes[0]
            b = batched_inputs[0]
            pred_smask_all = F.interpolate(pred_smask[:, :, :s[0], :s[1]], (b['height'], b['width']), mode='bicubic')[:, 0].sigmoid() > 0.5
            gt_smask = b['gt_masks_orisize']
            all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
            extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))

        all_batch_shape_iou = torch.stack(all_batch_shape_iou)
        processed_results = [{"mask_iou": all_batch_shape_iou[:, i]} for i in range(len(all_batch_shape_iou[0]))]
        return processed_results

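    # Referring-image path: runs the predictor with task='refimg' on an
    # exemplar image (with optional spatial queries) and returns its raw
    # outputs together with the padded image shape.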
    def evaluate_referring_image(self, batched_inputs, extra={}):
        assert self.task_switch['spatial']
        assert len(batched_inputs) == 1, "only support batch size equal to 1"
        assert self.interactive_mode == 'best'

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        if 'spatial_query' in batched_inputs[0]:
            image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
            nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
            multi_scale_features = [m.repeat(nm, 1, 1, 1) for m in multi_scale_features]
            mask_features = mask_features.repeat(nm, 1, 1, 1)

            query_index = self.sem_seg_head.predictor.query_index
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
        return outputs, images.tensor.shape

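    # Text-grounding evaluation: all referring phrases are encoded in one pass,
    # the decoder is run with task='grounding_eval', and each phrase is matched
    # to the grounding mask whose visual embedding is most similar to the
    # phrase embedding under `vl_similarity`.
    #
    # Rough shape sketch (an assumption for illustration): with Q grounding
    # queries, T phrases, and embedding dim C,
    #   v_emb: (Q, C), t_emb: (T, C) -> out_prob: (Q, T);
    #   matched_id = out_prob.max(0)[1] picks one query index per phrase.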
    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']
        #
        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]
        #
        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
        #
        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]
        #
        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #
        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
        #
        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:, None].shape[:-1], dtype=torch.bool, device=query_emb.device)

            extra['grounding_tokens'] = query_emb[:, None]
            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_gmasks'][idx]
            v_emb = outputs['pred_gtexts'][idx]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id, :, :]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

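    # Combined spatial + text grounding: the provided click/scribble masks are
    # slightly dilated with the 3x3 `dilation_kernel`, and each phrase is
    # evaluated with its own positive spatial query, one decoder forward per
    # phrase.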
| 589 | 
            +
                def evaluate_grounding_sptial(self, batched_inputs, mode):
         | 
| 590 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 591 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 592 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 593 | 
            +
                    assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    extra = {}
         | 
| 596 | 
            +
                    dilation = 3
         | 
| 597 | 
            +
                    pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
         | 
| 598 | 
            +
                    pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
         | 
| 599 | 
            +
                    pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
         | 
| 600 | 
            +
             | 
| 601 | 
            +
                    neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
         | 
| 602 | 
            +
                    neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                    mask_pred_results = []
         | 
| 605 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 606 | 
            +
                        grd_texts = batch_per_image['groundings']['texts']
         | 
| 607 | 
            +
                        grd_masks = []
         | 
| 608 | 
            +
                        for idx2, anno_text in enumerate(grd_texts):
         | 
| 609 | 
            +
                            extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
         | 
| 612 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 613 | 
            +
                            tokens = gtext['tokens']
         | 
| 614 | 
            +
                        
         | 
| 615 | 
            +
                            grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
         | 
| 616 | 
            +
                            non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
         | 
| 617 | 
            +
                            extra['grounding_tokens'] = grd_emb[:,None]
         | 
| 618 | 
            +
                            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                            assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
         | 
| 621 | 
            +
                            features = self.backbone(images.tensor)
         | 
| 622 | 
            +
                            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                            pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 625 | 
            +
                            v_emb = outputs['pred_gtexts'][idx]
         | 
| 626 | 
            +
                            t_emb = gtext['class_emb']
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 629 | 
            +
                            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 632 | 
            +
                            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 633 | 
            +
                            
         | 
| 634 | 
            +
                            matched_id = out_prob.max(0)[1]
         | 
| 635 | 
            +
                            grd_masks += [pred_gmasks[matched_id,:,:]]
         | 
| 636 | 
            +
                        mask_pred_results += [torch.cat(grd_masks)]
         | 
| 637 | 
            +
             | 
| 638 | 
            +
            # The block below is an alternative implementation for multi-object inference (kept commented out).
         | 
| 639 | 
            +
                    # mask_pred_results = []
         | 
| 640 | 
            +
                    # for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 641 | 
            +
                    #     grd_texts = batch_per_image['groundings']['texts']
         | 
| 642 | 
            +
                    #     grd_texts = [x[0] for x in grd_texts]
         | 
| 643 | 
            +
             | 
| 644 | 
            +
                    #     gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 645 | 
            +
                    #     token_emb = gtext['token_emb']
         | 
| 646 | 
            +
                    #     tokens = gtext['tokens']
         | 
| 647 | 
            +
                    #     query_emb = token_emb[tokens['attention_mask'].bool()]
         | 
| 648 | 
            +
                    #     non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
         | 
| 649 | 
            +
             | 
| 650 | 
            +
                    #     extra['grounding_tokens'] = query_emb[:,None]
         | 
| 651 | 
            +
                    #     extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 652 | 
            +
             | 
| 653 | 
            +
                    #     features = self.backbone(images.tensor)
         | 
| 654 | 
            +
                    #     outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 655 | 
            +
             | 
| 656 | 
            +
                    #     pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 657 | 
            +
                    #     v_emb = outputs['pred_gtexts'][idx]
         | 
| 658 | 
            +
                    #     t_emb = gtext['class_emb']
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                    #     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 661 | 
            +
                    #     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 662 | 
            +
             | 
| 663 | 
            +
                    #     temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 664 | 
            +
                    #     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 665 | 
            +
                        
         | 
| 666 | 
            +
                    #     matched_id = out_prob.max(0)[1]
         | 
| 667 | 
            +
                    #     mask_pred_results += [pred_gmasks[matched_id,:,:]]
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    for i in range(len(mask_pred_results)):
         | 
| 670 | 
            +
                        # upsample masks
         | 
| 671 | 
            +
                        mask_pred_results[i] = F.interpolate(
         | 
| 672 | 
            +
                            mask_pred_results[i][None,],
         | 
| 673 | 
            +
                            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 674 | 
            +
                            mode="bilinear",
         | 
| 675 | 
            +
                            align_corners=False,
         | 
| 676 | 
            +
                        )[0]
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                    processed_results = []
         | 
| 679 | 
            +
                    for mask_pred_result, input_per_image, image_size in zip(
         | 
| 680 | 
            +
                        mask_pred_results, batched_inputs, images.image_sizes
         | 
| 681 | 
            +
                    ):
         | 
| 682 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 683 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 684 | 
            +
                        processed_results.append({})
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                        mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 687 | 
            +
                            mask_pred_result, image_size, height, width
         | 
| 688 | 
            +
                        )
         | 
| 689 | 
            +
                        processed_results[-1]['grounding_mask'] = mask_pred_result
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                    return processed_results
         | 
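A minimal sketch of the text-to-mask matching performed in the loop above, assuming vl_similarity applies an exponentiated CLIP-style logit scale; the helper name match_mask_to_text is illustrative and not part of the source:

    import torch

    def match_mask_to_text(v_emb: torch.Tensor, t_emb: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
        # v_emb: [Q, C] per-query mask embeddings; t_emb: [1, C] text embedding of the grounding phrase
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        logits = logit_scale.exp() * v_emb @ t_emb.t()  # temperature-scaled cosine similarity, [Q, 1]
        return logits.max(0)[1]                         # index of the query whose mask best matches the text

The selected index is then used to pick pred_gmasks[matched_id] as the grounding mask for that phrase.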
| 692 | 
            +
             | 
| 693 | 
            +
                def prepare_targets(self, batched_inputs, images):
         | 
| 694 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 695 | 
            +
                    new_targets = []
         | 
| 696 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 697 | 
            +
                        targets_per_image = batch_per_image['instances'].to(self.device)
         | 
| 698 | 
            +
                        # pad gt
         | 
| 699 | 
            +
                        gt_masks = targets_per_image.gt_masks.tensor
         | 
| 700 | 
            +
                        padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 701 | 
            +
                        padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                        gt_boxes = targets_per_image.gt_boxes.tensor
         | 
| 704 | 
            +
                        ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
         | 
| 705 | 
            +
                        gt_boxes = gt_boxes / ratio
         | 
| 706 | 
            +
                        xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
         | 
| 707 | 
            +
                        gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
         | 
| 708 | 
            +
             | 
| 709 | 
            +
                        target_dict = {
         | 
| 710 | 
            +
                                "labels": targets_per_image.gt_classes,
         | 
| 711 | 
            +
                                "is_things": targets_per_image.is_things,
         | 
| 712 | 
            +
                                "masks": padded_masks,
         | 
| 713 | 
            +
                                "boxes": gt_boxes,
         | 
| 714 | 
            +
                                }
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                        if self.task_switch['spatial']:
         | 
| 717 | 
            +
                            # prepare targets for spatial query
         | 
| 718 | 
            +
                            target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                        if self.task_switch['grounding']:
         | 
| 721 | 
            +
                            grd_masks = batch_per_image['groundings']['masks']
         | 
| 722 | 
            +
                            grd_texts = batch_per_image['groundings']['texts']
         | 
| 723 | 
            +
                            grd_hash = batch_per_image['groundings']['hash']
         | 
| 724 | 
            +
                            grd_task = batch_per_image['groundings']['mode']
         | 
| 725 | 
            +
                            
         | 
| 726 | 
            +
                            if len(grd_masks) == 0:
         | 
| 727 | 
            +
                                padded_masks = None
         | 
| 728 | 
            +
                            else:
         | 
| 729 | 
            +
                                padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
         | 
| 730 | 
            +
                                padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 733 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 734 | 
            +
                            tokens = gtext['tokens']
         | 
| 735 | 
            +
                            
         | 
| 736 | 
            +
                            unique_hash_id = np.unique(grd_hash, return_index=True)[1]
         | 
| 737 | 
            +
                            selected_mask = np.zeros(len(grd_hash)).astype(bool)
         | 
| 738 | 
            +
                            selected_mask[unique_hash_id] = True
         | 
| 739 | 
            +
             | 
| 740 | 
            +
                            selected_token_emb = token_emb[selected_mask]
         | 
| 741 | 
            +
                            selected_attn_mask = tokens['attention_mask'][selected_mask]
         | 
| 742 | 
            +
                            query_emb = selected_token_emb[selected_attn_mask.bool()]
         | 
| 743 | 
            +
                            
         | 
| 744 | 
            +
                            class_idx = tokens['attention_mask'].sum(dim=-1) - 1
         | 
| 745 | 
            +
                            class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
         | 
| 746 | 
            +
                            class_emb = token_emb[class_idx]
         | 
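                            # Editorial note (not part of the commit): query_emb above gathers every
                            # non-padding token embedding of the de-duplicated grounding phrases, while
                            # class_idx = attention_mask.sum(-1) - 1 indexes the last attended token of
                            # each phrase, so class_emb is one sentence-level embedding per phrase.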
| 747 | 
            +
                            
         | 
| 748 | 
            +
                            target_dict['grounding_masks'] = padded_masks
         | 
| 749 | 
            +
                            target_dict['grounding_query_embs'] = query_emb
         | 
| 750 | 
            +
                            target_dict['grounding_class_embs'] = class_emb
         | 
| 751 | 
            +
                            target_dict['grounding_hash'] = grd_hash
         | 
| 752 | 
            +
                            target_dict['grounding_task'] = grd_task
         | 
| 753 | 
            +
             | 
| 754 | 
            +
                        new_targets.append(target_dict)
         | 
| 755 | 
            +
                    return new_targets
         | 
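For reference, prepare_targets converts absolute xyxy ground-truth boxes into cxcywh coordinates normalized by the padded image size; a small worked example with made-up numbers:

    # gt box (x0, y0, x1, y1) = (100, 50, 300, 250) on a 640x480 padded image
    # normalized xyxy: (100/640, 50/480, 300/640, 250/480) = (0.156, 0.104, 0.469, 0.521)
    # cxcywh: ((x0+x1)/2, (y0+y1)/2, x1-x0, y1-y0) = (0.3125, 0.3125, 0.3125, 0.417)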
| 756 | 
            +
             | 
| 757 | 
            +
                def prepare_next_spaital_mask(self, outputs, batched_inputs):
         | 
| 758 | 
            +
                    gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
         | 
| 759 | 
            +
                    if self.training:
         | 
| 760 | 
            +
                        gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
         | 
| 761 | 
            +
                    else:
         | 
| 762 | 
            +
                        gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                    pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
         | 
| 765 | 
            +
                    prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                    fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
         | 
| 768 | 
            +
                    fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                    # compute iou between gt and pred
         | 
| 771 | 
            +
                    iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
         | 
| 772 | 
            +
                    fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
         | 
| 773 | 
            +
                    fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
         | 
| 774 | 
            +
             | 
| 775 | 
            +
                    is_postive = fn_sum > fp_sum
         | 
| 776 | 
            +
                    # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
         | 
| 777 | 
            +
                    select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])
         | 
| 778 | 
            +
             | 
| 779 | 
            +
                # conv-based selection of the next click location
         | 
| 780 | 
            +
                    n,_,h,w=select_mask.shape
         | 
| 781 | 
            +
                    mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
         | 
| 782 | 
            +
                    max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
         | 
| 783 | 
            +
                    next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
         | 
| 784 | 
            +
                    next_mask = next_mask.view(n,-1)
         | 
| 785 | 
            +
                    next_mask[max_xy_idx] = True
         | 
| 786 | 
            +
                    next_mask = next_mask.reshape((n,1,h,w)).float()
         | 
| 787 | 
            +
                    dilation = 3
         | 
| 788 | 
            +
                    next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
         | 
| 789 | 
            +
             | 
| 790 | 
            +
            # zero out the next click mask once the IoU is already high enough
         | 
| 791 | 
            +
                    keep = (iou < 0.925)
         | 
| 792 | 
            +
                    next_mask = next_mask & keep.view(-1,1,1,1)
         | 
| 793 | 
            +
             | 
| 794 | 
            +
                    pos_mask = []
         | 
| 795 | 
            +
                    neg_mask = []
         | 
| 796 | 
            +
                    for idx, ip in enumerate(is_postive):
         | 
| 797 | 
            +
                        if ip:
         | 
| 798 | 
            +
                            pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
         | 
| 799 | 
            +
                            neg_mask += [outputs['spatial_query_neg_mask'][idx]]
         | 
| 800 | 
            +
                        else:
         | 
| 801 | 
            +
                            pos_mask += [outputs['spatial_query_pos_mask'][idx]]
         | 
| 802 | 
            +
                            neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
         | 
| 803 | 
            +
                    
         | 
| 804 | 
            +
                    if 'false_positive_mask' in outputs:
         | 
| 805 | 
            +
                        fp = outputs['false_positive_mask'] | fp
         | 
| 806 | 
            +
                    return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
         | 
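A compact, editorial summary of the click-simulation logic above (the numbered steps mirror the code rather than replace it):

    # 1. error = false negatives if they outnumber false positives, otherwise false positives
    # 2. distance_transform(~error) gives, per error pixel, the distance to the nearest non-error pixel
    # 3. the pixel with the largest distance (the most interior error pixel) becomes the next simulated click
    # 4. the single-pixel click is dilated with the 3x3 kernel into a small blob
    # 5. the blob is added to the positive prompts (false negative) or negative prompts (false positive)
    # 6. once IoU(pred, gt) >= 0.925, no further clicks are added for that sample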
| 807 | 
            +
             | 
| 808 | 
            +
                def semantic_inference(self, mask_cls, mask_pred):
         | 
| 809 | 
            +
                    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
         | 
| 810 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 811 | 
            +
                    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
         | 
| 812 | 
            +
                    return semseg
         | 
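The einsum in semantic_inference is a weighted sum of per-query mask probabilities by per-query class probabilities; an equivalent, more memory-hungry formulation for reference:

    # mask_cls:  [Q, C] softmax class scores with the void class dropped
    # mask_pred: [Q, H, W] sigmoid mask probabilities
    # semseg[c, h, w] = sum_q mask_cls[q, c] * mask_pred[q, h, w]
    semseg = (mask_cls[:, :, None, None] * mask_pred[:, None, :, :]).sum(dim=0)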
| 813 | 
            +
             | 
| 814 | 
            +
                def panoptic_inference(self, mask_cls, mask_pred):
         | 
| 815 | 
            +
                    scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
         | 
| 816 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 817 | 
            +
             | 
| 818 | 
            +
                    keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
         | 
| 819 | 
            +
                    cur_scores = scores[keep]
         | 
| 820 | 
            +
                    cur_classes = labels[keep]
         | 
| 821 | 
            +
                    cur_masks = mask_pred[keep]
         | 
| 822 | 
            +
                    cur_mask_cls = mask_cls[keep]
         | 
| 823 | 
            +
                    cur_mask_cls = cur_mask_cls[:, :-1]
         | 
| 824 | 
            +
             | 
| 825 | 
            +
                    cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                    h, w = cur_masks.shape[-2:]
         | 
| 828 | 
            +
                    panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
         | 
| 829 | 
            +
                    segments_info = []
         | 
| 830 | 
            +
             | 
| 831 | 
            +
                    current_segment_id = 0
         | 
| 832 | 
            +
             | 
| 833 | 
            +
                    if cur_masks.shape[0] == 0:
         | 
| 834 | 
            +
                        # We didn't detect any mask :(
         | 
| 835 | 
            +
                        return panoptic_seg, segments_info
         | 
| 836 | 
            +
                    else:
         | 
| 837 | 
            +
                        # take argmax
         | 
| 838 | 
            +
                        cur_mask_ids = cur_prob_masks.argmax(0)
         | 
| 839 | 
            +
                        stuff_memory_list = {}
         | 
| 840 | 
            +
                        for k in range(cur_classes.shape[0]):
         | 
| 841 | 
            +
                            pred_class = cur_classes[k].item()
         | 
| 842 | 
            +
                            isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 843 | 
            +
                            mask_area = (cur_mask_ids == k).sum().item()
         | 
| 844 | 
            +
                            original_area = (cur_masks[k] >= 0.5).sum().item()
         | 
| 845 | 
            +
                            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
         | 
| 846 | 
            +
             | 
| 847 | 
            +
                            if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
         | 
| 848 | 
            +
                                if mask_area / original_area < self.overlap_threshold:
         | 
| 849 | 
            +
                                    continue
         | 
| 850 | 
            +
             | 
| 851 | 
            +
                                # merge stuff regions
         | 
| 852 | 
            +
                                if not isthing:
         | 
| 853 | 
            +
                                    if int(pred_class) in stuff_memory_list.keys():
         | 
| 854 | 
            +
                                        panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
         | 
| 855 | 
            +
                                        continue
         | 
| 856 | 
            +
                                    else:
         | 
| 857 | 
            +
                                        stuff_memory_list[int(pred_class)] = current_segment_id + 1
         | 
| 858 | 
            +
             | 
| 859 | 
            +
                                current_segment_id += 1
         | 
| 860 | 
            +
                                panoptic_seg[mask] = current_segment_id
         | 
| 861 | 
            +
             | 
| 862 | 
            +
                                segments_info.append(
         | 
| 863 | 
            +
                                    {
         | 
| 864 | 
            +
                                        "id": current_segment_id,
         | 
| 865 | 
            +
                                        "isthing": bool(isthing),
         | 
| 866 | 
            +
                                        "category_id": int(pred_class),
         | 
| 867 | 
            +
                                    }
         | 
| 868 | 
            +
                                )
         | 
| 869 | 
            +
             | 
| 870 | 
            +
                        return panoptic_seg, segments_info
         | 
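In short, the panoptic merge above follows the usual MaskFormer-style rules (editorial summary):

    # - keep queries whose predicted class is not the void class and whose score exceeds object_mask_threshold
    # - assign every pixel to the highest-scoring query via argmax over score-weighted masks
    # - drop a segment if its visible area falls below overlap_threshold of its original mask area
    # - merge 'stuff' segments of the same class into a single segment id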
| 871 | 
            +
             | 
| 872 | 
            +
                def instance_inference(self, mask_cls, mask_pred, box_pred):
         | 
| 873 | 
            +
        # mask_pred is already processed to have the same shape as the original input
         | 
| 874 | 
            +
                    image_size = mask_pred.shape[-2:]
         | 
| 875 | 
            +
             | 
| 876 | 
            +
                    # [Q, K]
         | 
| 877 | 
            +
                    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
         | 
| 878 | 
            +
                    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
         | 
| 879 | 
            +
                    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
         | 
| 880 | 
            +
                    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
         | 
| 881 | 
            +
             | 
| 882 | 
            +
                    labels_per_image = labels[topk_indices]
         | 
| 883 | 
            +
                    topk_indices = (topk_indices // self.sem_seg_head.num_classes)
         | 
| 884 | 
            +
                    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
         | 
| 885 | 
            +
                    mask_pred = mask_pred[topk_indices]
         | 
| 886 | 
            +
                    if box_pred is not None:
         | 
| 887 | 
            +
                        box_pred = box_pred[topk_indices]
         | 
| 888 | 
            +
             | 
| 889 | 
            +
                    # if this is panoptic segmentation, we only keep the "thing" classes
         | 
| 890 | 
            +
                    if self.panoptic_on:
         | 
| 891 | 
            +
                        keep = torch.zeros_like(scores_per_image).bool()
         | 
| 892 | 
            +
                        for i, lab in enumerate(labels_per_image):
         | 
| 893 | 
            +
                            keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 894 | 
            +
             | 
| 895 | 
            +
                        scores_per_image = scores_per_image[keep]
         | 
| 896 | 
            +
                        labels_per_image = labels_per_image[keep]
         | 
| 897 | 
            +
                        mask_pred = mask_pred[keep]
         | 
| 898 | 
            +
             | 
| 899 | 
            +
                        if box_pred is not None:
         | 
| 900 | 
            +
                            box_pred = box_pred[keep]
         | 
| 901 | 
            +
             | 
| 902 | 
            +
                    result = Instances(image_size)
         | 
| 903 | 
            +
                    # mask (before sigmoid)
         | 
| 904 | 
            +
                    result.pred_masks = (mask_pred > 0).float()
         | 
| 905 | 
            +
                    # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 906 | 
            +
                    # Uncomment the following to get boxes from masks (this is slow)
         | 
| 907 | 
            +
             | 
| 908 | 
            +
                    if box_pred is not None:
         | 
| 909 | 
            +
                        result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
         | 
| 910 | 
            +
                    else:
         | 
| 911 | 
            +
                        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 912 | 
            +
             | 
| 913 | 
            +
                    # calculate average mask prob
         | 
| 914 | 
            +
                    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
         | 
| 915 | 
            +
                    result.scores = scores_per_image * mask_scores_per_image
         | 
| 916 | 
            +
                    result.pred_classes = labels_per_image
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                    return result
         | 
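One detail worth noting in instance_inference: the reported score multiplies the classification score by the mean mask probability inside the binarized mask, so a confident class prediction with a weak mask is down-weighted:

    # score_i = cls_score_i * mean( sigmoid(mask_logits_i)[ mask_logits_i > 0 ] )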
| 919 | 
            +
             | 
| 920 | 
            +
             | 
| 921 | 
            +
            @register_model
         | 
| 922 | 
            +
            def get_seem_model(cfg, **kwargs):
         | 
| 923 | 
            +
                return GeneralizedSEEM(cfg)
         | 
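The @register_model decorator adds get_seem_model to the architecture registry in modeling/architectures/build.py, so the model can be constructed by name from the config; roughly (the exact builder signature in build.py may differ):

    # model = build_model(cfg)  # looks up the registered factory and ends up calling get_seem_model(cfg)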
    	
        modeling/architectures/seem_model_v0.py
    ADDED
    
    | @@ -0,0 +1,1160 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # SEEM -- Segment Everything Everywhere All at Once
         | 
| 3 | 
            +
            # Licensed under The Apache License 2.0 [see LICENSE for details]
         | 
| 4 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 5 | 
            +
            # --------------------------------------------------------
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            from typing import Tuple
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            from torch import nn
         | 
| 13 | 
            +
            from torch.nn import functional as F
         | 
| 14 | 
            +
            from kornia.contrib import distance_transform
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from detectron2.structures import Boxes, ImageList, Instances, BitMasks
         | 
| 17 | 
            +
            from detectron2.utils.memory import retry_if_cuda_oom
         | 
| 18 | 
            +
            from detectron2.data import MetadataCatalog
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from .build import register_model
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from ..utils import configurable, get_class_names, get_iou
         | 
| 23 | 
            +
            from ..vision.backbone import build_backbone, Backbone
         | 
| 24 | 
            +
            from ..body import build_xdecoder_head
         | 
| 25 | 
            +
            from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
         | 
| 26 | 
            +
            from ..language import build_language_encoder
         | 
| 27 | 
            +
            from ..language.loss import vl_similarity
         | 
| 28 | 
            +
            from utilities.prompt_engineering import prompt_engineering
         | 
| 29 | 
            +
            from utilities.constants import COCO_PANOPTIC_CLASSES
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class GeneralizedSEEM(nn.Module):
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                @configurable
         | 
| 35 | 
            +
                def __init__(
         | 
| 36 | 
            +
                    self,
         | 
| 37 | 
            +
                    *,
         | 
| 38 | 
            +
                    backbone: Backbone,
         | 
| 39 | 
            +
                    sem_seg_head: nn.Module,
         | 
| 40 | 
            +
                    criterion: nn.Module,
         | 
| 41 | 
            +
                    losses: dict,
         | 
| 42 | 
            +
                    num_queries: int,
         | 
| 43 | 
            +
                    object_mask_threshold: float,
         | 
| 44 | 
            +
                    overlap_threshold: float,
         | 
| 45 | 
            +
                    metadata,
         | 
| 46 | 
            +
                    task_switch: dict,
         | 
| 47 | 
            +
                    phrase_prob: float,
         | 
| 48 | 
            +
                    size_divisibility: int,
         | 
| 49 | 
            +
                    sem_seg_postprocess_before_inference: bool,
         | 
| 50 | 
            +
                    pixel_mean: Tuple[float],
         | 
| 51 | 
            +
                    pixel_std: Tuple[float],
         | 
| 52 | 
            +
                    # inference
         | 
| 53 | 
            +
                    semantic_on: bool,
         | 
| 54 | 
            +
                    panoptic_on: bool,
         | 
| 55 | 
            +
                    instance_on: bool,
         | 
| 56 | 
            +
                    test_topk_per_image: int,
         | 
| 57 | 
            +
                    train_dataset_name: str,
         | 
| 58 | 
            +
                    interactive_mode: str,
         | 
| 59 | 
            +
                    interactive_iter: str,
         | 
| 60 | 
            +
                    dilation_kernel: torch.Tensor,
         | 
| 61 | 
            +
                    train_max_iter: int,
         | 
| 62 | 
            +
                ):
         | 
| 63 | 
            +
                    """
         | 
| 64 | 
            +
                    Args:
         | 
| 65 | 
            +
                        backbone: a backbone module, must follow detectron2's backbone interface
         | 
| 66 | 
            +
                        sem_seg_head: a module that predicts semantic segmentation from backbone features
         | 
| 67 | 
            +
                        criterion: a module that defines the loss
         | 
| 68 | 
            +
                        num_queries: int, number of queries
         | 
| 69 | 
            +
                        object_mask_threshold: float, threshold to filter query based on classification score
         | 
| 70 | 
            +
                            for panoptic segmentation inference
         | 
| 71 | 
            +
                        overlap_threshold: overlap threshold used in general inference for panoptic segmentation
         | 
| 72 | 
            +
                        metadata: dataset meta, get `thing` and `stuff` category names for panoptic
         | 
| 73 | 
            +
                            segmentation inference
         | 
| 74 | 
            +
                        size_divisibility: Some backbones require the input height and width to be divisible by a
         | 
| 75 | 
            +
                    specific integer. We can use this to override such a requirement.
         | 
| 76 | 
            +
                        sem_seg_postprocess_before_inference: whether to resize the prediction back
         | 
| 77 | 
            +
                            to original input size before semantic segmentation inference or after.
         | 
| 78 | 
            +
                            For high-resolution dataset like Mapillary, resizing predictions before
         | 
| 79 | 
            +
                            inference will cause OOM error.
         | 
| 80 | 
            +
                        pixel_mean, pixel_std: list or tuple with #channels element, representing
         | 
| 81 | 
            +
                            the per-channel mean and std to be used to normalize the input image
         | 
| 82 | 
            +
                        semantic_on: bool, whether to output semantic segmentation prediction
         | 
| 83 | 
            +
                        instance_on: bool, whether to output instance segmentation prediction
         | 
| 84 | 
            +
                        panoptic_on: bool, whether to output panoptic segmentation prediction
         | 
| 85 | 
            +
                        test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
         | 
| 86 | 
            +
                    """
         | 
| 87 | 
            +
                    super().__init__()
         | 
| 88 | 
            +
                    self.backbone = backbone
         | 
| 89 | 
            +
                    self.sem_seg_head = sem_seg_head
         | 
| 90 | 
            +
                    self.criterion = criterion
         | 
| 91 | 
            +
                    self.losses = losses
         | 
| 92 | 
            +
                    self.num_queries = num_queries
         | 
| 93 | 
            +
                    self.overlap_threshold = overlap_threshold
         | 
| 94 | 
            +
                    self.object_mask_threshold = object_mask_threshold
         | 
| 95 | 
            +
                    self.metadata = metadata
         | 
| 96 | 
            +
                    if size_divisibility < 0:
         | 
| 97 | 
            +
                        # use backbone size_divisibility if not set
         | 
| 98 | 
            +
                        size_divisibility = self.backbone.size_divisibility
         | 
| 99 | 
            +
                    self.size_divisibility = size_divisibility
         | 
| 100 | 
            +
                    self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
         | 
| 101 | 
            +
                    self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         | 
| 102 | 
            +
                    self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    # additional args
         | 
| 105 | 
            +
                    self.semantic_on = semantic_on
         | 
| 106 | 
            +
                    self.instance_on = instance_on
         | 
| 107 | 
            +
                    self.panoptic_on = panoptic_on
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # caption argument
         | 
| 110 | 
            +
                    self.task_switch = task_switch
         | 
| 111 | 
            +
                    self.phrase_prob = phrase_prob
         | 
| 112 | 
            +
                    self.train_max_iter = train_max_iter
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    self.test_topk_per_image = test_topk_per_image
         | 
| 115 | 
            +
                    self.train_class_names = get_class_names(train_dataset_name)
         | 
| 116 | 
            +
                    self.interactive_mode = interactive_mode
         | 
| 117 | 
            +
                    self.interactive_iter = interactive_iter
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    if not self.semantic_on:
         | 
| 120 | 
            +
                        assert self.sem_seg_postprocess_before_inference
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    self.register_buffer("dilation_kernel", dilation_kernel)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                @classmethod
         | 
| 125 | 
            +
                def from_config(cls, cfg):
         | 
| 126 | 
            +
                    enc_cfg = cfg['MODEL']['ENCODER']
         | 
| 127 | 
            +
                    dec_cfg = cfg['MODEL']['DECODER']
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # Loss parameters:
         | 
| 130 | 
            +
                    deep_supervision = dec_cfg['DEEP_SUPERVISION']
         | 
| 131 | 
            +
                    no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # loss weights
         | 
| 134 | 
            +
                    loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
         | 
| 135 | 
            +
                                    'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
         | 
| 136 | 
            +
                                    'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
         | 
| 137 | 
            +
                                    'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
         | 
| 138 | 
            +
                                    'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
         | 
| 141 | 
            +
                                        'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    task_switch = {'bbox': dec_cfg.get('DETECTION', False),
         | 
| 144 | 
            +
                                   'mask': dec_cfg['MASK'].get('ENABLED', True),
         | 
| 145 | 
            +
                                   'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
         | 
| 146 | 
            +
                                   'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
         | 
| 147 | 
            +
                                   'openimage': openimage_switch}
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
         | 
| 150 | 
            +
                                    'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
         | 
| 151 | 
            +
                                    'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
         | 
| 152 | 
            +
                                    'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
         | 
| 155 | 
            +
                                    "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
         | 
| 156 | 
            +
                                    "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    extra = {'task_switch': task_switch}
         | 
| 159 | 
            +
                    backbone = build_backbone(cfg)
         | 
| 160 | 
            +
                    lang_encoder = build_language_encoder(cfg)        
         | 
| 161 | 
            +
                    sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    # building criterion
         | 
| 164 | 
            +
                    matcher = HungarianMatcher(
         | 
| 165 | 
            +
                        cost_class=loss_weights['mask']['ce'],
         | 
| 166 | 
            +
                        cost_mask=loss_weights['mask']['bce'],
         | 
| 167 | 
            +
                        cost_dice=loss_weights['mask']['dice'],
         | 
| 168 | 
            +
                        num_points=dec_cfg['TRAIN_NUM_POINTS'],
         | 
| 169 | 
            +
                        spatial_cost=spatial_cost,
         | 
| 170 | 
            +
                    )
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    # init weight dict and criterion loss functions.
         | 
| 173 | 
            +
                    losses = {'seg': [], 'openimage': []}
         | 
| 174 | 
            +
                    if task_switch['mask']:
         | 
| 175 | 
            +
                        losses['seg'] += ["labels", "masks"]
         | 
| 176 | 
            +
                    if task_switch['spatial']:
         | 
| 177 | 
            +
                        losses['seg'] += ["spatials"]
         | 
| 178 | 
            +
                    if task_switch['grounding']:
         | 
| 179 | 
            +
                        losses['seg'] += ["groundings"]
         | 
| 180 | 
            +
                    if task_switch['openimage']:
         | 
| 181 | 
            +
                        losses['openimage'] += ["labels_openimage", "masks"]
         | 
| 182 | 
            +
                    if task_switch['openimage']['grounding']:
         | 
| 183 | 
            +
                        losses['openimage'] += ["groundings"]
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    weight_dict = {}
         | 
| 186 | 
            +
                    for key, turn_on in task_switch.items():
         | 
| 187 | 
            +
                        if turn_on:
         | 
| 188 | 
            +
                            if isinstance(loss_weights[key], dict):
         | 
| 189 | 
            +
                                # HACK it should support bbox in the future
         | 
| 190 | 
            +
                                for key_, weight in loss_weights[key].items():
         | 
| 191 | 
            +
                                    weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
         | 
| 192 | 
            +
                            else:
         | 
| 193 | 
            +
                                weight_dict["loss_{}_0".format(key)] = loss_weights[key]
         | 
| 194 | 
            +
             | 
| 195 | 
            +
        # generate the full weight dict and remove layers that are not computed.
         | 
| 196 | 
            +
                    if deep_supervision:
         | 
| 197 | 
            +
                        dec_layers = dec_cfg['DEC_LAYERS']
         | 
| 198 | 
            +
                        aux_weight_dict = {}
         | 
| 199 | 
            +
                        for i in range(dec_layers - 1):
         | 
| 200 | 
            +
                            for k, v in weight_dict.items():
         | 
| 201 | 
            +
                                if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
         | 
| 202 | 
            +
                                    continue
         | 
| 203 | 
            +
                                aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
         | 
| 204 | 
            +
                        weight_dict.update(aux_weight_dict)
         | 
| 205 | 
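                    # Editorial example of the keys produced above (assuming DEC_LAYERS = 10 and
                    # TOP_MASK_LAYERS = 10): loss_mask_ce_0 ... loss_mask_ce_9, plus the matching
                    # loss_mask_dice_* and loss_mask_bce_* entries; decoder layers beyond a task's
                    # top_x_layers limit are skipped, so their auxiliary predictions are not supervised.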
            +
             | 
| 206 | 
            +
                    grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
         | 
| 207 | 
            +
        # generate criterion for the loss function.
         | 
| 208 | 
            +
                    criterion = SetCriterion(
         | 
| 209 | 
            +
                        sem_seg_head.num_classes,
         | 
| 210 | 
            +
                        matcher=matcher,
         | 
| 211 | 
            +
                        weight_dict=weight_dict,
         | 
| 212 | 
            +
                        top_x_layers=top_x_layers,
         | 
| 213 | 
            +
                        eos_coef=no_object_weight,
         | 
| 214 | 
            +
                        losses=[],
         | 
| 215 | 
            +
                        num_points=dec_cfg['TRAIN_NUM_POINTS'],
         | 
| 216 | 
            +
                        oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
         | 
| 217 | 
            +
                        importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
         | 
| 218 | 
            +
                        grounding_weight=grd_weight,
         | 
| 219 | 
            +
                    )
         | 
| 220 | 
            +
             | 
| 221 | 
            +
        # extra logic
         | 
| 222 | 
            +
                    train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
         | 
| 223 | 
            +
                    train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
         | 
| 224 | 
            +
                    phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
         | 
| 225 | 
            +
                    interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
         | 
| 226 | 
            +
                    interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    dilation = 3
         | 
| 229 | 
            +
                    dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
         | 
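                # Editorial sketch of how this all-ones kernel is used: convolving a one-hot click map
                # with it (padding = dilation // 2) turns a single pixel into a 3x3 blob, e.g.
                #   point = torch.zeros(1, 1, 5, 5, device=dilation_kernel.device); point[0, 0, 2, 2] = 1
                #   blob = F.conv2d(point, dilation_kernel, padding=1) > 0  # 3x3 True region around the click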
| 230 | 
            +
             | 
| 231 | 
        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "losses": losses,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                or dec_cfg['TEST']['PANOPTIC_ON']
                or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            "task_switch": task_switch,
            "phrase_prob": phrase_prob,
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
            "train_dataset_name": train_dataset_name,
            "interactive_mode": interactive_mode,
            "interactive_iter": interactive_iter,
            "dilation_kernel": dilation_kernel,
            "train_max_iter": train_max_iter,
        }

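    # Note (explanatory, added): the dict returned above is expected to map one-to-one
    # onto the keyword arguments of this model's __init__, following the usual
    # detectron2-style `from_config` convention used elsewhere in this repo.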
    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs, mode='default'):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from the input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represents the panoptic output.
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describes each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        if self.training:
            losses = {}
            if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
                losses_seg = self.forward_seg(batched_inputs)
                losses.update(losses_seg)
            if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
                losses_openimage = self.forward_openimage(batched_inputs['openimage'])
                losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
                losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
                losses.update(losses_openimage)
            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else: # remove this loss if not specified in `weight_dict`
                    losses.pop(k)
            return losses
        else:
            if mode == 'interactive':
                return self.evaluate_interactive(batched_inputs)
            elif mode == 'interactive_grounding':
                return self.evaluate_interactive_grounding(batched_inputs)
            elif mode == 'grounding_spatial':
                return self.evaluate_grounding_sptial(batched_inputs, mode)
            elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
                return self.evaluate_grounding(batched_inputs, mode)
            else:
                return self.evaluate(batched_inputs)

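    # Illustrative usage (added; variable names are assumptions): at inference time the
    # caller selects the task through `mode`, e.g.
    #   model.eval()
    #   with torch.no_grad():
    #       results = model(batched_inputs, mode='grounding_refcoco')
    # while mode='default' falls through to the generic evaluate() path below.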
    def forward_seg(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)

        extra = {}
        # mask classification target
        if "instances" in batched_inputs[0]:
            # input bounding box is checked to be correct.
            targets = self.prepare_targets(batched_inputs, images)

            if self.task_switch['grounding']:
                grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
                grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
                non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
                grounding_tokens[non_zero_query_mask] = 0

                extra['grounding_tokens'] = grounding_tokens
                extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            if self.task_switch['spatial']:
                pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
                neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
                fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
                extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})

        features = self.backbone(images.tensor)
        mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        # forward the spatial branch only, without gradient
        if self.task_switch['spatial']:
            with torch.no_grad():
                # sample a random number of spatial refinement iterations in [0, train_max_iter]
                rand_iter_num = random.randint(0, self.train_max_iter)
                for i in range(rand_iter_num):
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
                    extra.update(outputs)
                    extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')

        extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
                 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
                 'false_positive_mask': extra['false_positive_mask']}
        # bipartite matching-based loss
        self.criterion.losses = self.losses['seg'] # seg criterion losses
        losses = self.criterion(outputs, targets, extra)

        del outputs
        return losses

                def evaluate(self, batched_inputs):
         | 
| 375 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 376 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 377 | 
            +
                    
         | 
| 378 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 379 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 382 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 383 | 
            +
                    outputs = self.sem_seg_head(features, target_queries=queries_grounding)
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    mask_cls_results = outputs["pred_logits"]
         | 
| 386 | 
            +
                    mask_pred_results = outputs["pred_masks"]
         | 
| 387 | 
            +
                    box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    # upsample masks
         | 
| 390 | 
            +
                    mask_pred_results = F.interpolate(
         | 
| 391 | 
            +
                        mask_pred_results,
         | 
| 392 | 
            +
                        size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 393 | 
            +
                        mode="bilinear",
         | 
| 394 | 
            +
                        align_corners=False,
         | 
| 395 | 
            +
                    )
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    input_size = mask_pred_results.shape[-2:]
         | 
| 398 | 
            +
                    del outputs
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    processed_results = []
         | 
| 401 | 
            +
                    for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
         | 
| 402 | 
            +
                        mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
         | 
| 403 | 
            +
                    ):
         | 
| 404 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 405 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 406 | 
            +
                        processed_results.append({})
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                        if self.sem_seg_postprocess_before_inference:
         | 
| 409 | 
            +
                            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 410 | 
            +
                                mask_pred_result, image_size, height, width
         | 
| 411 | 
            +
                            )
         | 
| 412 | 
            +
                            mask_cls_result = mask_cls_result.to(mask_pred_result)
         | 
| 413 | 
            +
             | 
| 414 | 
            +
                        # semantic segmentation inference
         | 
| 415 | 
            +
                        if self.semantic_on:
         | 
| 416 | 
            +
                            r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
         | 
| 417 | 
            +
                            if not self.sem_seg_postprocess_before_inference:
         | 
| 418 | 
            +
                                r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
         | 
| 419 | 
            +
                            processed_results[-1]["sem_seg"] = r
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                        # panoptic segmentation inference
         | 
| 422 | 
            +
                        if self.panoptic_on:
         | 
| 423 | 
            +
                            panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
         | 
| 424 | 
            +
                            processed_results[-1]["panoptic_seg"] = panoptic_r
         | 
| 425 | 
            +
                        
         | 
| 426 | 
            +
                        # instance segmentation inference
         | 
| 427 | 
            +
                        if self.instance_on:
         | 
| 428 | 
            +
                            if self.task_switch['bbox']:
         | 
| 429 | 
            +
                                box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
         | 
| 430 | 
            +
                            instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
         | 
| 431 | 
            +
                            processed_results[-1]["instances"] = instance_r
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    return processed_results
         | 
| 434 | 
            +
             | 
| 435 | 
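    # Note (explanatory, added): retry_if_cuda_oom (detectron2) re-runs the wrapped call
    # after clearing the CUDA cache and, if it still runs out of GPU memory, retries with
    # CPU tensors; that is why the post-processing calls above are wrapped in it.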
    def evaluate_interactive(self, batched_inputs):
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        extra = {}

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm,1,1,1)

        all_batch_shape_iou = []
        pred_smask_pointer = None
        prev_smask_pointer = None
        pred_smask_all = None

        # visualization code
        # v_pred_mask = []
        # v_pos_mask = []
        # v_neg_mask = []
        # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
        query_index = self.sem_seg_head.predictor.query_index
        if self.interactive_mode in ['best', 'best_random']:
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
        elif self.interactive_mode == 'random':
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
            extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
        else:
            assert False, "invalid interactive mode"

        for i in range(self.interactive_iter):
            # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
            extra.update(outputs)
            pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
            # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]

            s = image_sizes[0]
            b = batched_inputs[0]
            pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
            gt_smask = b['gt_masks_orisize']
            ious = get_iou(gt_smask, pred_smask_all)
            all_batch_shape_iou += [ious]
            if (ious > 0.9).sum() == len(ious):
                all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
                break
            if self.interactive_mode in ['best', 'best_random']:
                extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
            elif self.interactive_mode == 'random':
                extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
            else:
                assert False, "invalid interactive mode"
        all_batch_shape_iou = torch.stack(all_batch_shape_iou)
        processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]

        return processed_results

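    # Note (explanatory, added): the interactive loop above stops early once every mask
    # reaches IoU > 0.9 and pads the remaining iterations with the final IoU values, so
    # "mask_iou" always has `interactive_iter` entries per object. `get_iou` is assumed
    # to return one IoU value per predicted/ground-truth mask pair.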
    def evaluate_interactive_single(self, batched_inputs, extra={}):
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm,1,1,1)

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
        pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')

        s = image_sizes[0]
        b = batched_inputs[0]
        pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
        pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
        ious = []
        if 'gt_masks_orisize' in b:
            gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
            ious = get_iou(gt_smask, pred_smask_ori)
        processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
        return processed_results

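    # Note (explanatory, added): evaluate_interactive_single is the one-step variant for
    # callers that drive the click loop themselves; unlike evaluate_interactive it also
    # returns the predicted masks (at original size and at padded size) and resamples with
    # bicubic rather than bilinear interpolation.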
    def evaluate_interactive_grounding(self, batched_inputs):
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        extra = {}

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm,1,1,1)

        all_batch_shape_iou = []
        pred_smask_pointer = None
        prev_smask_pointer = None
        pred_smask_all = None

        # visualization code
        # v_pred_mask = []
        # v_pos_mask = []
        # v_neg_mask = []
        # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
        query_index = self.sem_seg_head.predictor.query_index
        if self.interactive_mode in ['best', 'best_random']:
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
        elif self.interactive_mode == 'random':
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
            extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
        else:
            assert False, "invalid interactive mode"

        grd_texts = batched_inputs[0]['classes']
        gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
        token_emb = gtext['token_emb']
        tokens = gtext['tokens']
        query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
        non_zero_query_mask = (query_emb.sum(dim=-1) < 0)

        extra['grounding_tokens'] = query_emb
        extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

        for i in range(self.interactive_iter):
            # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
            extra.update(outputs)
            pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
            # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]

            s = image_sizes[0]
            b = batched_inputs[0]
            pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
            gt_smask = b['gt_masks_orisize']
            ious = get_iou(gt_smask, pred_smask_all)
            all_batch_shape_iou += [ious]
            if (ious > 0.9).sum() == len(ious):
                all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
                break
            if self.interactive_mode in ['best', 'best_random']:
                extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
            elif self.interactive_mode == 'random':
                extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
            else:
                assert False, "invalid interactive mode"
        all_batch_shape_iou = torch.stack(all_batch_shape_iou)
        processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]

        # visualization
        # VL.step()
        # import cv2
        # v_masks = []
        # v_pos_masks = []
        # v_neg_masks = []
        # txt = []

        # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
        # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
        # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
        # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
        # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
        #     # dilate x,y
        #     x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
        #     y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
        #     acc_pos_mask += x
        #     acc_neg_mask += y

        #     v_masks += [z]
        #     v_pos_masks += [acc_pos_mask.clip(0,1)]
        #     v_neg_masks += [acc_neg_mask.clip(0,1)]
        #     txt += ["pred_{}".format(str(iou[0].item())[0:5])]

        # VL.add_image(img[:,:,::-1])
        # VL.insert(mask_img, "gt_mask")
        # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
        return processed_results

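    # Illustrative sketch (added; tensor names and sizes are assumptions): the text
    # queries above are packed the same way as during training, e.g. two prompts with
    # 3 and 5 valid tokens give
    #   query_emb = nn.utils.rnn.pad_sequence([emb_3xC, emb_5xC], padding_value=-1)  # (5, 2, C)
    # and `non_zero_query_mask` flags the padded positions so attention can ignore them.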
    def evaluate_referring_image(self, batched_inputs, extra={}):
        assert self.task_switch['spatial']
        assert len(batched_inputs) == 1, "only support batch size equal to 1"
        assert self.interactive_mode == 'best'

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        if 'spatial_query' in batched_inputs[0]:
            image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
            nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
            multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
            mask_features = mask_features.repeat(nm,1,1,1)

            query_index = self.sem_seg_head.predictor.query_index
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
        return outputs, images.tensor.shape

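    # Note (explanatory, added): unlike the other evaluate_* methods, evaluate_referring_image
    # returns the raw predictor outputs together with the padded input shape; the caller is
    # expected to resize and threshold the masks itself.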
    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        # per-text (single-object) inference, kept for reference:
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']
        #
        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]
        #
        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
        #
        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]
        #
        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #
        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
        #
        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # multi-object inference: encode all grounding texts together and run a single forward pass.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)

            extra['grounding_tokens'] = query_emb[:,None]
            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_gmasks'][idx]
            v_emb = outputs['pred_gtexts'][idx]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

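    # Illustrative sketch (added; `vl_similarity` is defined elsewhere in this repo): the
    # text/mask matching above amounts to a temperature-scaled cosine similarity followed
    # by an argmax over mask queries, roughly
    #   logits = temperature.exp() * (v_emb @ t_emb.t())   # (num_queries, num_texts)
    #   matched_id = logits.max(0)[1]                       # best query per text prompt
    # so each grounding phrase selects the mask embedding it is most similar to.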
            +
                def evaluate_grounding_sptial(self, batched_inputs, mode):
         | 
| 789 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 790 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 791 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 792 | 
            +
                    assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
         | 
| 793 | 
            +
             | 
| 794 | 
            +
                    extra = {}
         | 
| 795 | 
            +
                    dilation = 3
         | 
| 796 | 
            +
                    pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
         | 
| 797 | 
            +
                    pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
         | 
| 798 | 
            +
                    pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
         | 
| 799 | 
            +
             | 
| 800 | 
            +
                    neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
         | 
| 801 | 
            +
                    neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    mask_pred_results = []
         | 
| 804 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 805 | 
            +
                        grd_texts = batch_per_image['groundings']['texts']
         | 
| 806 | 
            +
                        grd_masks = []
         | 
| 807 | 
            +
                        for idx2, anno_text in enumerate(grd_texts):
         | 
| 808 | 
            +
                            extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
         | 
| 809 | 
            +
             | 
| 810 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
         | 
| 811 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 812 | 
            +
                            tokens = gtext['tokens']
         | 
| 813 | 
            +
                        
         | 
| 814 | 
            +
                            grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
         | 
| 815 | 
            +
                            non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
         | 
| 816 | 
            +
                            extra['grounding_tokens'] = grd_emb[:,None]
         | 
| 817 | 
            +
                            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                assert len(images.tensor) == 1, "grounding evaluation only supports batch size 1 for now"
         | 
| 820 | 
            +
                            features = self.backbone(images.tensor)
         | 
| 821 | 
            +
                            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 822 | 
            +
                            
         | 
| 823 | 
            +
                            pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 824 | 
            +
                            v_emb = outputs['pred_gtexts'][idx]
         | 
| 825 | 
            +
                            t_emb = gtext['class_emb']
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 828 | 
            +
                            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 831 | 
            +
                            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 832 | 
            +
                            
         | 
| 833 | 
            +
                            matched_id = out_prob.max(0)[1]
         | 
| 834 | 
            +
                            grd_masks += [pred_gmasks[matched_id,:,:]]
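                # vl_similarity computes a temperature-scaled similarity between each predicted grounding
                # query embedding and the text embedding; the best-matching query's mask is kept for this prompt.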
         | 
| 835 | 
            +
                            # grd_masks += [outputs['prev_mask'][0]]
         | 
| 836 | 
            +
             | 
| 837 | 
            +
                        mask_pred_results += [torch.cat(grd_masks)]
         | 
| 838 | 
            +
             | 
| 839 | 
            +
        # commented out below: batched multi-object inference variant.
         | 
| 840 | 
            +
                    # mask_pred_results = []
         | 
| 841 | 
            +
                    # for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 842 | 
            +
                    #     grd_texts = batch_per_image['groundings']['texts']
         | 
| 843 | 
            +
                    #     grd_texts = [x[0] for x in grd_texts]
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    #     gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 846 | 
            +
                    #     token_emb = gtext['token_emb']
         | 
| 847 | 
            +
                    #     tokens = gtext['tokens']
         | 
| 848 | 
            +
                    #     query_emb = token_emb[tokens['attention_mask'].bool()]
         | 
| 849 | 
            +
                    #     non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
         | 
| 850 | 
            +
             | 
| 851 | 
            +
                    #     extra['grounding_tokens'] = query_emb[:,None]
         | 
| 852 | 
            +
                    #     extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                    #     features = self.backbone(images.tensor)
         | 
| 855 | 
            +
                    #     outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 856 | 
            +
             | 
| 857 | 
            +
                    #     pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 858 | 
            +
                    #     v_emb = outputs['pred_gtexts'][idx]
         | 
| 859 | 
            +
                    #     t_emb = gtext['class_emb']
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                    #     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 862 | 
            +
                    #     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                    #     temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 865 | 
            +
                    #     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 866 | 
            +
                        
         | 
| 867 | 
            +
                    #     matched_id = out_prob.max(0)[1]
         | 
| 868 | 
            +
                    #     mask_pred_results += [pred_gmasks[matched_id,:,:]]
         | 
| 869 | 
            +
             | 
| 870 | 
            +
                    for i in range(len(mask_pred_results)):
         | 
| 871 | 
            +
                        # upsample masks
         | 
| 872 | 
            +
                        mask_pred_results[i] = F.interpolate(
         | 
| 873 | 
            +
                            mask_pred_results[i][None,],
         | 
| 874 | 
            +
                            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 875 | 
            +
                            mode="bilinear",
         | 
| 876 | 
            +
                            align_corners=False,
         | 
| 877 | 
            +
                        )[0]
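        # Masks are upsampled to the padded input resolution here; sem_seg_postprocess below crops the
        # padding and resizes each mask to the original image height/width.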
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                    processed_results = []
         | 
| 880 | 
            +
                    for mask_pred_result, input_per_image, image_size in zip(
         | 
| 881 | 
            +
                        mask_pred_results, batched_inputs, images.image_sizes
         | 
| 882 | 
            +
                    ):
         | 
| 883 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 884 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 885 | 
            +
                        processed_results.append({})
         | 
| 886 | 
            +
             | 
| 887 | 
            +
                        mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 888 | 
            +
                            mask_pred_result, image_size, height, width
         | 
| 889 | 
            +
                        )
         | 
| 890 | 
            +
                        processed_results[-1]['grounding_mask'] = mask_pred_result
         | 
| 891 | 
            +
             | 
| 892 | 
            +
                    return processed_results
         | 
| 893 | 
            +
             | 
| 894 | 
            +
                def prepare_targets(self, batched_inputs, images):
         | 
| 895 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 896 | 
            +
                    new_targets = []
         | 
| 897 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 898 | 
            +
                        targets_per_image = batch_per_image['instances'].to(self.device)
         | 
| 899 | 
            +
                        # pad gt
         | 
| 900 | 
            +
                        gt_masks = targets_per_image.gt_masks.tensor
         | 
| 901 | 
            +
                        padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 902 | 
            +
                        padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 903 | 
            +
             | 
| 904 | 
            +
                        gt_boxes = targets_per_image.gt_boxes.tensor
         | 
| 905 | 
            +
                        ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
         | 
| 906 | 
            +
                        gt_boxes = gt_boxes / ratio
         | 
| 907 | 
            +
                        xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
         | 
| 908 | 
            +
                        gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
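            # Boxes go from absolute XYXY pixels to normalized (cx, cy, w, h) relative to the padded
            # image, e.g. cx = (x1 + x2) / (2 * w_pad) and w = (x2 - x1) / w_pad.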
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                        target_dict = {
         | 
| 911 | 
            +
                                "labels": targets_per_image.gt_classes,
         | 
| 912 | 
            +
                                "is_things": targets_per_image.is_things,
         | 
| 913 | 
            +
                                "masks": padded_masks,
         | 
| 914 | 
            +
                                "boxes": gt_boxes,
         | 
| 915 | 
            +
                                }
         | 
| 916 | 
            +
             | 
| 917 | 
            +
                        if self.task_switch['spatial']:
         | 
| 918 | 
            +
                            # prepare targets for spatial query
         | 
| 919 | 
            +
                            target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
         | 
| 920 | 
            +
             | 
| 921 | 
            +
                        if self.task_switch['grounding']:
         | 
| 922 | 
            +
                            grd_masks = batch_per_image['groundings']['masks']
         | 
| 923 | 
            +
                            grd_texts = batch_per_image['groundings']['texts']
         | 
| 924 | 
            +
                            grd_hash = batch_per_image['groundings']['hash']
         | 
| 925 | 
            +
                            grd_task = batch_per_image['groundings']['mode']
         | 
| 926 | 
            +
                            
         | 
| 927 | 
            +
                            if len(grd_masks) == 0:
         | 
| 928 | 
            +
                                padded_masks = None
         | 
| 929 | 
            +
                            else:
         | 
| 930 | 
            +
                                padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
         | 
| 931 | 
            +
                                padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
         | 
| 932 | 
            +
             | 
| 933 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 934 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 935 | 
            +
                            tokens = gtext['tokens']
         | 
| 936 | 
            +
                            
         | 
| 937 | 
            +
                            unique_hash_id = np.unique(grd_hash, return_index=True)[1]
         | 
| 938 | 
            +
                selected_mask = np.zeros(len(grd_hash)).astype(bool)  # note: the np.bool alias was removed in NumPy >= 1.24
         | 
| 939 | 
            +
                            selected_mask[unique_hash_id] = True
         | 
| 940 | 
            +
             | 
| 941 | 
            +
                            selected_token_emb = token_emb[selected_mask]
         | 
| 942 | 
            +
                            selected_attn_mask = tokens['attention_mask'][selected_mask]
         | 
| 943 | 
            +
                            query_emb = selected_token_emb[selected_attn_mask.bool()]
         | 
| 944 | 
            +
                            
         | 
| 945 | 
            +
                            class_idx = tokens['attention_mask'].sum(dim=-1) - 1
         | 
| 946 | 
            +
                            class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
         | 
| 947 | 
            +
                            class_emb = token_emb[class_idx]
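                # class_idx indexes the last non-padded token of each phrase (attention_mask.sum() - 1),
                # so class_emb holds one phrase-level embedding per grounding text.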
         | 
| 948 | 
            +
                            
         | 
| 949 | 
            +
                            target_dict['grounding_masks'] = padded_masks
         | 
| 950 | 
            +
                            target_dict['grounding_query_embs'] = query_emb
         | 
| 951 | 
            +
                            target_dict['grounding_class_embs'] = class_emb
         | 
| 952 | 
            +
                            target_dict['grounding_hash'] = grd_hash
         | 
| 953 | 
            +
                            target_dict['grounding_task'] = grd_task
         | 
| 954 | 
            +
             | 
| 955 | 
            +
                        new_targets.append(target_dict)
         | 
| 956 | 
            +
                    return new_targets
         | 
| 957 | 
            +
             | 
| 958 | 
            +
                def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
         | 
| 959 | 
            +
                    gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
         | 
| 960 | 
            +
                    if self.training:
         | 
| 961 | 
            +
                        gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
         | 
| 962 | 
            +
                    else:
         | 
| 963 | 
            +
                        gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
         | 
| 964 | 
            +
             | 
| 965 | 
            +
                    pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
         | 
| 966 | 
            +
                    prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
         | 
| 967 | 
            +
             | 
| 968 | 
            +
                    fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
         | 
| 969 | 
            +
                    fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
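        # fn / fp are the current error regions (missed foreground / spurious prediction), excluding
        # pixels already covered by earlier positive or negative spatial query masks.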
         | 
| 970 | 
            +
             | 
| 971 | 
            +
                    # compute iou between gt and pred
         | 
| 972 | 
            +
                    iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
         | 
| 973 | 
            +
                    fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
         | 
| 974 | 
            +
                    fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
         | 
| 975 | 
            +
             | 
| 976 | 
            +
        is_positive = fn_sum > fp_sum
         | 
| 977 | 
            +
        # is_positive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
         | 
| 978 | 
            +
        select_mask = torch.stack([fn[i] if is_positive[i] else fp[i] for i in range(len(fn))])
         | 
| 979 | 
            +
             | 
| 980 | 
            +
                    # conv implementation
         | 
| 981 | 
            +
                    n,_,h,w = select_mask.shape
         | 
| 982 | 
            +
                    mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
         | 
| 983 | 
            +
                    if mode == 'best':
         | 
| 984 | 
            +
                        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
         | 
| 985 | 
            +
                    elif mode == 'best_random':
         | 
| 986 | 
            +
                        max_xy_idx = torch.stack([torch.arange(n), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
         | 
| 987 | 
            +
                    next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
         | 
| 988 | 
            +
                    next_mask = next_mask.view(n,-1)
         | 
| 989 | 
            +
                    next_mask[max_xy_idx] = True
         | 
| 990 | 
            +
                    next_mask = next_mask.reshape((n,1,h,w)).float()
         | 
| 991 | 
            +
                    dilation = 3
         | 
| 992 | 
            +
                    next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
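        # The next simulated click is placed at the pixel with the largest distance-transform value in
        # the selected error region (i.e. farthest from its boundary), or at a random interior pixel in
        # 'best_random' mode, and is then dilated with the same kernel used for the initial clicks.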
         | 
| 993 | 
            +
             | 
| 994 | 
            +
                    # determine whether next mask is zero
         | 
| 995 | 
            +
                    keep = (iou < 0.925)
         | 
| 996 | 
            +
                    next_mask = next_mask & keep.view(-1,1,1,1)
         | 
| 997 | 
            +
             | 
| 998 | 
            +
                    pos_mask = []
         | 
| 999 | 
            +
                    neg_mask = []
         | 
| 1000 | 
            +
        for idx, ip in enumerate(is_positive):
         | 
| 1001 | 
            +
                        if ip:
         | 
| 1002 | 
            +
                            pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
         | 
| 1003 | 
            +
                            neg_mask += [outputs['spatial_query_neg_mask'][idx]]
         | 
| 1004 | 
            +
                        else:
         | 
| 1005 | 
            +
                            pos_mask += [outputs['spatial_query_pos_mask'][idx]]
         | 
| 1006 | 
            +
                            neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
         | 
| 1007 | 
            +
                    
         | 
| 1008 | 
            +
                    if 'false_positive_mask' in outputs:
         | 
| 1009 | 
            +
                        fp = outputs['false_positive_mask'] | fp
         | 
| 1010 | 
            +
                    return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
         | 
| 1011 | 
            +
             | 
| 1012 | 
            +
                def semantic_inference(self, mask_cls, mask_pred):
         | 
| 1013 | 
            +
                    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
         | 
| 1014 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 1015 | 
            +
                    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
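        # Per-class semantic map: for each class c, sum over queries q of P(c | q) * sigmoid(mask_q).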
         | 
| 1016 | 
            +
                    return semseg
         | 
| 1017 | 
            +
             | 
| 1018 | 
            +
                def panoptic_inference(self, mask_cls, mask_pred):
         | 
| 1019 | 
            +
                    scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
         | 
| 1020 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 1021 | 
            +
             | 
| 1022 | 
            +
                    keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
         | 
| 1023 | 
            +
                    cur_scores = scores[keep]
         | 
| 1024 | 
            +
                    cur_classes = labels[keep]
         | 
| 1025 | 
            +
                    cur_masks = mask_pred[keep]
         | 
| 1026 | 
            +
                    cur_mask_cls = mask_cls[keep]
         | 
| 1027 | 
            +
                    cur_mask_cls = cur_mask_cls[:, :-1]
         | 
| 1028 | 
            +
             | 
| 1029 | 
            +
                    cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
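        # Each pixel is later assigned to the query with the highest score-weighted mask probability
        # (argmax over cur_prob_masks); segments whose visible area falls below overlap_threshold are dropped.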
         | 
| 1030 | 
            +
             | 
| 1031 | 
            +
                    h, w = cur_masks.shape[-2:]
         | 
| 1032 | 
            +
                    panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
         | 
| 1033 | 
            +
                    segments_info = []
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                    current_segment_id = 0
         | 
| 1036 | 
            +
             | 
| 1037 | 
            +
                    if cur_masks.shape[0] == 0:
         | 
| 1038 | 
            +
                        # We didn't detect any mask :(
         | 
| 1039 | 
            +
                        return panoptic_seg, segments_info
         | 
| 1040 | 
            +
                    else:
         | 
| 1041 | 
            +
                        # take argmax
         | 
| 1042 | 
            +
                        cur_mask_ids = cur_prob_masks.argmax(0)
         | 
| 1043 | 
            +
                        stuff_memory_list = {}
         | 
| 1044 | 
            +
                        for k in range(cur_classes.shape[0]):
         | 
| 1045 | 
            +
                            pred_class = cur_classes[k].item()
         | 
| 1046 | 
            +
                            isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 1047 | 
            +
                            mask_area = (cur_mask_ids == k).sum().item()
         | 
| 1048 | 
            +
                            original_area = (cur_masks[k] >= 0.5).sum().item()
         | 
| 1049 | 
            +
                            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
         | 
| 1050 | 
            +
             | 
| 1051 | 
            +
                            if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
         | 
| 1052 | 
            +
                                if mask_area / original_area < self.overlap_threshold:
         | 
| 1053 | 
            +
                                    continue
         | 
| 1054 | 
            +
             | 
| 1055 | 
            +
                                # merge stuff regions
         | 
| 1056 | 
            +
                                if not isthing:
         | 
| 1057 | 
            +
                                    if int(pred_class) in stuff_memory_list.keys():
         | 
| 1058 | 
            +
                                        panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
         | 
| 1059 | 
            +
                                        continue
         | 
| 1060 | 
            +
                                    else:
         | 
| 1061 | 
            +
                                        stuff_memory_list[int(pred_class)] = current_segment_id + 1
         | 
| 1062 | 
            +
             | 
| 1063 | 
            +
                                current_segment_id += 1
         | 
| 1064 | 
            +
                                panoptic_seg[mask] = current_segment_id
         | 
| 1065 | 
            +
             | 
| 1066 | 
            +
                                segments_info.append(
         | 
| 1067 | 
            +
                                    {
         | 
| 1068 | 
            +
                                        "id": current_segment_id,
         | 
| 1069 | 
            +
                                        "isthing": bool(isthing),
         | 
| 1070 | 
            +
                                        "category_id": int(pred_class),
         | 
| 1071 | 
            +
                                    }
         | 
| 1072 | 
            +
                                )
         | 
| 1073 | 
            +
             | 
| 1074 | 
            +
                        return panoptic_seg, segments_info
         | 
| 1075 | 
            +
             | 
| 1076 | 
            +
                def instance_inference(self, mask_cls, mask_pred, box_pred):
         | 
| 1077 | 
            +
                    # mask_pred is already processed to have the same shape as original input
         | 
| 1078 | 
            +
                    image_size = mask_pred.shape[-2:]
         | 
| 1079 | 
            +
             | 
| 1080 | 
            +
                    # [Q, K]
         | 
| 1081 | 
            +
                    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
         | 
| 1082 | 
            +
                    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
         | 
| 1083 | 
            +
                    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
         | 
| 1084 | 
            +
                    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
         | 
| 1085 | 
            +
             | 
| 1086 | 
            +
                    labels_per_image = labels[topk_indices]
         | 
| 1087 | 
            +
                    topk_indices = (topk_indices // self.sem_seg_head.num_classes)
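        # topk was taken over the flattened (num_queries * num_classes) score matrix, so the class label
        # comes from the tiled `labels` tensor and the owning query index is topk_index // num_classes.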
         | 
| 1088 | 
            +
                    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
         | 
| 1089 | 
            +
                    mask_pred = mask_pred[topk_indices]
         | 
| 1090 | 
            +
                    if box_pred is not None:
         | 
| 1091 | 
            +
                        box_pred = box_pred[topk_indices]
         | 
| 1092 | 
            +
             | 
| 1093 | 
            +
                    # if this is panoptic segmentation, we only keep the "thing" classes
         | 
| 1094 | 
            +
                    if self.panoptic_on:
         | 
| 1095 | 
            +
                        keep = torch.zeros_like(scores_per_image).bool()
         | 
| 1096 | 
            +
                        for i, lab in enumerate(labels_per_image):
         | 
| 1097 | 
            +
                            keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 1098 | 
            +
             | 
| 1099 | 
            +
                        scores_per_image = scores_per_image[keep]
         | 
| 1100 | 
            +
                        labels_per_image = labels_per_image[keep]
         | 
| 1101 | 
            +
                        mask_pred = mask_pred[keep]
         | 
| 1102 | 
            +
             | 
| 1103 | 
            +
                        if box_pred is not None:
         | 
| 1104 | 
            +
                            box_pred = box_pred[keep]
         | 
| 1105 | 
            +
             | 
| 1106 | 
            +
                    result = Instances(image_size)
         | 
| 1107 | 
            +
                    # mask (before sigmoid)
         | 
| 1108 | 
            +
                    result.pred_masks = (mask_pred > 0).float()
         | 
| 1109 | 
            +
                    # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 1110 | 
            +
                    # Uncomment the following to get boxes from masks (this is slow)
         | 
| 1111 | 
            +
             | 
| 1112 | 
            +
                    if box_pred is not None:
         | 
| 1113 | 
            +
                        result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
         | 
| 1114 | 
            +
                    else:
         | 
| 1115 | 
            +
                        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 1116 | 
            +
             | 
| 1117 | 
            +
                    # calculate average mask prob
         | 
| 1118 | 
            +
                    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
         | 
| 1119 | 
            +
                    result.scores = scores_per_image * mask_scores_per_image
         | 
| 1120 | 
            +
                    result.pred_classes = labels_per_image
         | 
| 1121 | 
            +
             | 
| 1122 | 
            +
                    return result
         | 
| 1123 | 
            +
             | 
| 1124 | 
            +
                def prepare_targets4query(self, targets, images, topk=5):
         | 
| 1125 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 1126 | 
            +
                    new_targets = []
         | 
| 1127 | 
            +
                    new_queries = []
         | 
| 1128 | 
            +
                    for targets_per_image in targets:
         | 
| 1129 | 
            +
            # we randomly sample at most topk concepts
         | 
| 1130 | 
            +
                        unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
         | 
| 1131 | 
            +
                        selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
         | 
| 1132 | 
            +
                        new_targets_per_image = []
         | 
| 1133 | 
            +
                        new_queries_per_image = []
         | 
| 1134 | 
            +
                        for clss in selected_target_classes:
         | 
| 1135 | 
            +
                            indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
         | 
| 1136 | 
            +
                            # pad gt
         | 
| 1137 | 
            +
                            gt_masks = targets_per_image.gt_masks[indices]
         | 
| 1138 | 
            +
                            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 1139 | 
            +
                            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 1140 | 
            +
             | 
| 1141 | 
            +
                            # convert class into concept name and then token seq
         | 
| 1142 | 
            +
                            self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([COCO_PANOPTIC_CLASSES[clss]], name='grounding')
         | 
| 1143 | 
            +
                            query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
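                # get_text_embeddings encodes the class-name prompt and stores it on the language encoder
                # (retrieved here via the 'grounding_text_embeddings' attribute) to act as the text query
                # for this sampled class.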
         | 
| 1144 | 
            +
             | 
| 1145 | 
            +
                            new_targets.append(
         | 
| 1146 | 
            +
                                {
         | 
| 1147 | 
            +
                                    "labels": targets_per_image.gt_classes[indices],
         | 
| 1148 | 
            +
                                    "masks": padded_masks,
         | 
| 1149 | 
            +
                                }
         | 
| 1150 | 
            +
                            )
         | 
| 1151 | 
            +
                            new_queries_per_image.append(query)
         | 
| 1152 | 
            +
                        new_queries.append(new_queries_per_image)
         | 
| 1153 | 
            +
             | 
| 1154 | 
            +
                    return new_targets, new_queries
         | 
| 1155 | 
            +
             | 
| 1156 | 
            +
             | 
| 1157 | 
            +
             | 
| 1158 | 
            +
            @register_model
         | 
| 1159 | 
            +
            def get_seem_model(cfg, **kwargs):
         | 
| 1160 | 
            +
                return GeneralizedSEEM(cfg)
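
# Rough usage sketch (illustrative only; config loading and the exact evaluation entry point follow
# main.py / configs/biomedparse_inference.yaml and may differ):
#
#   model = get_seem_model(cfg).eval().to('cuda')
#   with torch.no_grad():
#       results = model.evaluate_grounding(batched_inputs, mode=None)
#   mask = results[0]['grounding_mask']  # one mask per text prompt, at the original image size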
         | 
    	
        modeling/architectures/seem_model_v1.py
    ADDED
    
    | @@ -0,0 +1,1179 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # SEEM -- Segment Everything Everywhere All at Once
         | 
| 3 | 
            +
            # Licensed under The Apache License 2.0 [see LICENSE for details]
         | 
| 4 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 5 | 
            +
            # --------------------------------------------------------
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            from typing import Tuple
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            from torch import nn
         | 
| 13 | 
            +
            from torch.nn import functional as F
         | 
| 14 | 
            +
            from kornia.contrib import distance_transform
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from detectron2.structures import Boxes, ImageList, Instances, BitMasks
         | 
| 17 | 
            +
            from detectron2.utils.memory import retry_if_cuda_oom
         | 
| 18 | 
            +
            from detectron2.data import MetadataCatalog
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from .build import register_model
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from ..utils import configurable, get_class_names, get_iou, Spatial_ImageList
         | 
| 23 | 
            +
            from ..vision.backbone import build_backbone, Backbone
         | 
| 24 | 
            +
            from ..body import build_xdecoder_head
         | 
| 25 | 
            +
            from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
         | 
| 26 | 
            +
            from ..language import build_language_encoder
         | 
| 27 | 
            +
            from ..language.loss import vl_similarity
         | 
| 28 | 
            +
            from utilities.prompt_engineering import prompt_engineering
         | 
| 29 | 
            +
            from utilities.constants import COCO_PANOPTIC_CLASSES, BIOMED_CLASSES
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class GeneralizedSEEM(nn.Module):
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                @configurable
         | 
| 35 | 
            +
                def __init__(
         | 
| 36 | 
            +
                    self,
         | 
| 37 | 
            +
                    *,
         | 
| 38 | 
            +
                    backbone: Backbone,
         | 
| 39 | 
            +
                    sem_seg_head: nn.Module,
         | 
| 40 | 
            +
                    criterion: nn.Module,
         | 
| 41 | 
            +
                    losses: dict,
         | 
| 42 | 
            +
                    num_queries: int,
         | 
| 43 | 
            +
                    object_mask_threshold: float,
         | 
| 44 | 
            +
                    overlap_threshold: float,
         | 
| 45 | 
            +
                    metadata,
         | 
| 46 | 
            +
                    task_switch: dict,
         | 
| 47 | 
            +
                    phrase_prob: float,
         | 
| 48 | 
            +
                    size_divisibility: int,
         | 
| 49 | 
            +
                    sem_seg_postprocess_before_inference: bool,
         | 
| 50 | 
            +
                    pixel_mean: Tuple[float],
         | 
| 51 | 
            +
                    pixel_std: Tuple[float],
         | 
| 52 | 
            +
                    # inference
         | 
| 53 | 
            +
                    semantic_on: bool,
         | 
| 54 | 
            +
                    panoptic_on: bool,
         | 
| 55 | 
            +
                    instance_on: bool,
         | 
| 56 | 
            +
                    test_topk_per_image: int,
         | 
| 57 | 
            +
                    train_dataset_name: str,
         | 
| 58 | 
            +
                    interactive_mode: str,
         | 
| 59 | 
            +
                    interactive_iter: str,
         | 
| 60 | 
            +
                    dilation_kernel: torch.Tensor,
         | 
| 61 | 
            +
                    train_max_iter: int,
         | 
| 62 | 
            +
                    binary_classes: bool,
         | 
| 63 | 
            +
                    standard_text_for_eval: bool,
         | 
| 64 | 
            +
                ):
         | 
| 65 | 
            +
                    """
         | 
| 66 | 
            +
                    Args:
         | 
| 67 | 
            +
                        backbone: a backbone module, must follow detectron2's backbone interface
         | 
| 68 | 
            +
                        sem_seg_head: a module that predicts semantic segmentation from backbone features
         | 
| 69 | 
            +
                        criterion: a module that defines the loss
         | 
| 70 | 
            +
                        num_queries: int, number of queries
         | 
| 71 | 
            +
                        object_mask_threshold: float, threshold to filter query based on classification score
         | 
| 72 | 
            +
                            for panoptic segmentation inference
         | 
| 73 | 
            +
                        overlap_threshold: overlap threshold used in general inference for panoptic segmentation
         | 
| 74 | 
            +
                        metadata: dataset meta, get `thing` and `stuff` category names for panoptic
         | 
| 75 | 
            +
                            segmentation inference
         | 
| 76 | 
            +
                        size_divisibility: Some backbones require the input height and width to be divisible by a
         | 
| 77 | 
            +
                            specific integer. We can use this to override such requirement.
         | 
| 78 | 
            +
                        sem_seg_postprocess_before_inference: whether to resize the prediction back
         | 
| 79 | 
            +
                            to original input size before semantic segmentation inference or after.
         | 
| 80 | 
            +
                For high-resolution datasets like Mapillary, resizing predictions before
         | 
| 81 | 
            +
                inference will cause OOM errors.
         | 
| 82 | 
            +
                        pixel_mean, pixel_std: list or tuple with #channels element, representing
         | 
| 83 | 
            +
                            the per-channel mean and std to be used to normalize the input image
         | 
| 84 | 
            +
                        semantic_on: bool, whether to output semantic segmentation prediction
         | 
| 85 | 
            +
                        instance_on: bool, whether to output instance segmentation prediction
         | 
| 86 | 
            +
                        panoptic_on: bool, whether to output panoptic segmentation prediction
         | 
| 87 | 
            +
                        test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
         | 
| 88 | 
            +
                    """
         | 
| 89 | 
            +
                    super().__init__()
         | 
| 90 | 
            +
                    self.backbone = backbone
         | 
| 91 | 
            +
                    self.sem_seg_head = sem_seg_head
         | 
| 92 | 
            +
                    self.criterion = criterion
         | 
| 93 | 
            +
                    self.losses = losses
         | 
| 94 | 
            +
                    self.num_queries = num_queries
         | 
| 95 | 
            +
                    self.overlap_threshold = overlap_threshold
         | 
| 96 | 
            +
                    self.object_mask_threshold = object_mask_threshold
         | 
| 97 | 
            +
                    self.metadata = metadata
         | 
| 98 | 
            +
                    if size_divisibility < 0:
         | 
| 99 | 
            +
                        # use backbone size_divisibility if not set
         | 
| 100 | 
            +
                        size_divisibility = self.backbone.size_divisibility
         | 
| 101 | 
            +
                    self.size_divisibility = size_divisibility
         | 
| 102 | 
            +
                    self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
         | 
| 103 | 
            +
                    self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         | 
| 104 | 
            +
                    self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    # additional args
         | 
| 107 | 
            +
                    self.semantic_on = semantic_on
         | 
| 108 | 
            +
                    self.instance_on = instance_on
         | 
| 109 | 
            +
                    self.panoptic_on = panoptic_on
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    # caption argument
         | 
| 112 | 
            +
                    self.task_switch = task_switch
         | 
| 113 | 
            +
                    self.phrase_prob = phrase_prob
         | 
| 114 | 
            +
                    self.train_max_iter = train_max_iter
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    self.test_topk_per_image = test_topk_per_image
         | 
| 117 | 
            +
                    self.train_class_names = get_class_names(train_dataset_name)
         | 
| 118 | 
            +
                    if binary_classes:
         | 
| 119 | 
            +
                        self.train_class_names = ['target', 'background']
         | 
| 120 | 
            +
                    self.interactive_mode = interactive_mode
         | 
| 121 | 
            +
                    self.interactive_iter = interactive_iter
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    if not self.semantic_on:
         | 
| 124 | 
            +
                        assert self.sem_seg_postprocess_before_inference
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    self.register_buffer("dilation_kernel", dilation_kernel)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    self.standard_text_for_eval = standard_text_for_eval
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                @classmethod
         | 
| 131 | 
            +
                def from_config(cls, cfg):
         | 
| 132 | 
            +
                    enc_cfg = cfg['MODEL']['ENCODER']
         | 
| 133 | 
            +
                    dec_cfg = cfg['MODEL']['DECODER']
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # Loss parameters:
         | 
| 136 | 
            +
                    deep_supervision = dec_cfg['DEEP_SUPERVISION']
         | 
| 137 | 
            +
                    no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    # loss weights
         | 
| 140 | 
            +
                    loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
         | 
| 141 | 
            +
                                    'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
         | 
| 142 | 
            +
                                    'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
         | 
| 143 | 
            +
                                    'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
         | 
| 144 | 
            +
                                    'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
         | 
| 147 | 
            +
                                        'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    task_switch = {'bbox': dec_cfg.get('DETECTION', False),
         | 
| 150 | 
            +
                                   'mask': dec_cfg['MASK'].get('ENABLED', True),
         | 
| 151 | 
            +
                                   'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
         | 
| 152 | 
            +
                                   'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
         | 
| 153 | 
            +
                                   'openimage': openimage_switch}
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
         | 
| 156 | 
            +
                                    'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
         | 
| 157 | 
            +
                                    'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
         | 
| 158 | 
            +
                                    'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
         | 
| 161 | 
            +
                                    "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
         | 
| 162 | 
            +
                                    "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    extra = {'task_switch': task_switch}
         | 
| 165 | 
            +
                    backbone = build_backbone(cfg)
         | 
| 166 | 
            +
                    lang_encoder = build_language_encoder(cfg)        
         | 
| 167 | 
            +
                    sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    # building criterion
         | 
| 170 | 
            +
                    matcher = HungarianMatcher(
         | 
| 171 | 
            +
                        cost_class=loss_weights['mask']['ce'],
         | 
| 172 | 
            +
                        cost_mask=loss_weights['mask']['bce'],
         | 
| 173 | 
            +
                        cost_dice=loss_weights['mask']['dice'],
         | 
| 174 | 
            +
                        num_points=dec_cfg['TRAIN_NUM_POINTS'],
         | 
| 175 | 
            +
                        spatial_cost=spatial_cost,
         | 
| 176 | 
            +
                    )
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    # init weight dict and criterion loss functions.
         | 
| 179 | 
            +
                    losses = {'seg': [], 'openimage': []}
         | 
| 180 | 
            +
                    if task_switch['mask']:
         | 
| 181 | 
            +
                        losses['seg'] += ["labels", "masks"]
         | 
| 182 | 
            +
                    if task_switch['spatial']:
         | 
| 183 | 
            +
                        losses['seg'] += ["spatials"]
         | 
| 184 | 
            +
                    if task_switch['grounding']:
         | 
| 185 | 
            +
                        losses['seg'] += ["groundings"]
         | 
| 186 | 
            +
                    if task_switch['openimage']:
         | 
| 187 | 
            +
                        losses['openimage'] += ["labels_openimage", "masks"]
         | 
| 188 | 
            +
                    if task_switch['openimage']['grounding']:
         | 
| 189 | 
            +
                        losses['openimage'] += ["groundings"]
         | 
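                    # losses['seg'] and losses['openimage'] list the loss names handed to SetCriterion;
                    # the active list is swapped onto criterion.losses per branch (see forward_seg below).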
| 190 | 
            +
             | 
| 191 | 
            +
                    weight_dict = {}
         | 
| 192 | 
            +
                    for key, turn_on in task_switch.items():
         | 
| 193 | 
            +
                        if turn_on:
         | 
| 194 | 
            +
                            if isinstance(loss_weights[key], dict):
         | 
| 195 | 
            +
                                # HACK: bbox weights should also be supported here in the future
         | 
| 196 | 
            +
                                for key_, weight in loss_weights[key].items():
         | 
| 197 | 
            +
                                    weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
         | 
| 198 | 
            +
                            else:
         | 
| 199 | 
            +
                                weight_dict["loss_{}_0".format(key)] = loss_weights[key]
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # generate the full weight dict and remove layers that are not computed.
         | 
| 202 | 
            +
                    if deep_supervision:
         | 
| 203 | 
            +
                        dec_layers = dec_cfg['DEC_LAYERS']
         | 
| 204 | 
            +
                        aux_weight_dict = {}
         | 
| 205 | 
            +
                        for i in range(dec_layers - 1):
         | 
| 206 | 
            +
                            for k, v in weight_dict.items():
         | 
| 207 | 
            +
                                if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
         | 
| 208 | 
            +
                                    continue
         | 
| 209 | 
            +
                                aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
         | 
| 210 | 
            +
                        weight_dict.update(aux_weight_dict)
         | 
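                    # Illustrative example (values assumed): with CLASS_WEIGHT=2.0 and DEC_LAYERS=10,
                    # the keys become 'loss_mask_ce_0' ... 'loss_mask_ce_9', each mapped to 2.0; the
                    # suffix indexes the decoder layer, and layers at or beyond top_x_layers for a task
                    # are skipped so they contribute no auxiliary loss.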
| 211 | 
            +
             | 
| 212 | 
            +
                    grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
         | 
| 213 | 
            +
                    # generate the criterion for the loss function.
         | 
| 214 | 
            +
                    criterion = SetCriterion(
         | 
| 215 | 
            +
                        sem_seg_head.num_classes,
         | 
| 216 | 
            +
                        matcher=matcher,
         | 
| 217 | 
            +
                        weight_dict=weight_dict,
         | 
| 218 | 
            +
                        top_x_layers=top_x_layers,
         | 
| 219 | 
            +
                        eos_coef=no_object_weight,
         | 
| 220 | 
            +
                        losses=[],
         | 
| 221 | 
            +
                        num_points=dec_cfg['TRAIN_NUM_POINTS'],
         | 
| 222 | 
            +
                        oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
         | 
| 223 | 
            +
                        importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
         | 
| 224 | 
            +
                        grounding_weight=grd_weight,
         | 
| 225 | 
            +
                    )
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # extra logic
         | 
| 228 | 
            +
                    train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
         | 
| 229 | 
            +
                    train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
         | 
| 230 | 
            +
                    phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
         | 
| 231 | 
            +
                    interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
         | 
| 232 | 
            +
                    interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    dilation = 3
         | 
| 235 | 
            +
                    dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
         | 
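                    # 3x3 all-ones kernel created directly on the current CUDA device, so constructing
                    # the model this way assumes a GPU is available; judging by its name it is used to
                    # dilate the sampled spatial (click/scribble) masks during interactive training.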
| 236 | 
            +
             | 
| 237 | 
            +
                    return {
         | 
| 238 | 
            +
                        "backbone": backbone,
         | 
| 239 | 
            +
                        "sem_seg_head": sem_seg_head,
         | 
| 240 | 
            +
                        "criterion": criterion,
         | 
| 241 | 
            +
                        "losses": losses,
         | 
| 242 | 
            +
                        "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
         | 
| 243 | 
            +
                        "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
         | 
| 244 | 
            +
                        "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
         | 
| 245 | 
            +
                        "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
         | 
| 246 | 
            +
                        "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
         | 
| 247 | 
            +
                        "sem_seg_postprocess_before_inference": (
         | 
| 248 | 
            +
                            dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
         | 
| 249 | 
            +
                            or dec_cfg['TEST']['PANOPTIC_ON']
         | 
| 250 | 
            +
                            or dec_cfg['TEST']['INSTANCE_ON']
         | 
| 251 | 
            +
                        ),
         | 
| 252 | 
            +
                        "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
         | 
| 253 | 
            +
                        "pixel_std": cfg['INPUT']['PIXEL_STD'],
         | 
| 254 | 
            +
                        "task_switch": task_switch,
         | 
| 255 | 
            +
                        "phrase_prob": phrase_prob,
         | 
| 256 | 
            +
                        # inference
         | 
| 257 | 
            +
                        "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
         | 
| 258 | 
            +
                        "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
         | 
| 259 | 
            +
                        "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
         | 
| 260 | 
            +
                        "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
         | 
| 261 | 
            +
                        "train_dataset_name": train_dataset_name,
         | 
| 262 | 
            +
                        "interactive_mode": interactive_mode,
         | 
| 263 | 
            +
                        "interactive_iter": interactive_iter,
         | 
| 264 | 
            +
                        "dilation_kernel": dilation_kernel,
         | 
| 265 | 
            +
                        "train_max_iter": train_max_iter,
         | 
| 266 | 
            +
                        "binary_classes": enc_cfg['BINARY_CLASSES'],
         | 
| 267 | 
            +
                        "standard_text_for_eval": cfg['STANDARD_TEXT_FOR_EVAL'],
         | 
| 268 | 
            +
                    }
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                @property
         | 
| 271 | 
            +
                def device(self):
         | 
| 272 | 
            +
                    return self.pixel_mean.device
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                def forward(self, batched_inputs, mode='default'):
         | 
| 275 | 
            +
                    """
         | 
| 276 | 
            +
                    Args:
         | 
| 277 | 
            +
                        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
         | 
| 278 | 
            +
                            Each item in the list contains the inputs for one image.
         | 
| 279 | 
            +
                            For now, each item in the list is a dict that contains:
         | 
| 280 | 
            +
                               * "image": Tensor, image in (C, H, W) format.
         | 
| 281 | 
            +
                               * "instances": per-region ground truth
         | 
| 282 | 
            +
                               * Other information that's included in the original dicts, such as:
         | 
| 283 | 
            +
                                 "height", "width" (int): the output resolution of the model (may be different
         | 
| 284 | 
            +
                                 from input resolution), used in inference.
         | 
| 285 | 
            +
                    Returns:
         | 
| 286 | 
            +
                        list[dict]:
         | 
| 287 | 
            +
                            each dict has the results for one image. The dict contains the following keys:
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                            * "sem_seg":
         | 
| 290 | 
            +
                                A Tensor that represents the
         | 
| 291 | 
            +
                            per-pixel segmentation predicted by the head.
         | 
| 292 | 
            +
                                The prediction has shape KxHxW that represents the logits of
         | 
| 293 | 
            +
                                each class for each pixel.
         | 
| 294 | 
            +
                            * "panoptic_seg":
         | 
| 295 | 
            +
                            A tuple that represents the panoptic output:
         | 
| 296 | 
            +
                                panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
         | 
| 297 | 
            +
                                segments_info (list[dict]): Describe each segment in `panoptic_seg`.
         | 
| 298 | 
            +
                                    Each dict contains keys "id", "category_id", "isthing".
         | 
| 299 | 
            +
                    """
         | 
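                    # A minimal, hypothetical inference call matching the contract above:
                    #   batched_inputs = [{"image": torch.zeros(3, 512, 512), "height": 512, "width": 512}]
                    #   results = model(batched_inputs)            # mode='default' -> self.evaluate(...)
                    #   sem_seg_logits = results[0]["sem_seg"]     # (K, H, W), present when semantic_on is True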
| 300 | 
            +
                    if self.training:
         | 
| 301 | 
            +
                        losses = {}
         | 
| 302 | 
            +
                        if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
         | 
| 303 | 
            +
                            losses_seg = self.forward_seg(batched_inputs)
         | 
| 304 | 
            +
                            losses.update(losses_seg)
         | 
| 305 | 
            +
                        if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
         | 
| 306 | 
            +
                            losses_openimage = self.forward_openimage(batched_inputs['openimage'])
         | 
| 307 | 
            +
                            losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
         | 
| 308 | 
            +
                            losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
         | 
| 309 | 
            +
                            losses.update(losses_openimage)
         | 
| 310 | 
            +
                        for k in list(losses.keys()):
         | 
| 311 | 
            +
                            if k in self.criterion.weight_dict:
         | 
| 312 | 
            +
                                losses[k] *= self.criterion.weight_dict[k]
         | 
| 313 | 
            +
                            else: # remove this loss if not specified in `weight_dict`
         | 
| 314 | 
            +
                                losses.pop(k)
         | 
| 315 | 
            +
                        return losses
         | 
| 316 | 
            +
                    else:
         | 
| 317 | 
            +
                        if mode == 'interactive':
         | 
| 318 | 
            +
                            return self.evaluate_interactive(batched_inputs)
         | 
| 319 | 
            +
                        elif mode == 'interactive_grounding':
         | 
| 320 | 
            +
                            return self.evaluate_interactive_grounding(batched_inputs)
         | 
| 321 | 
            +
                        elif mode == 'grounding_spatial':
         | 
| 322 | 
            +
                            return self.evaluate_grounding_sptial(batched_inputs, mode)
         | 
| 323 | 
            +
                        elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
         | 
| 324 | 
            +
                            return self.evaluate_grounding(batched_inputs, mode)
         | 
| 325 | 
            +
                        else:
         | 
| 326 | 
            +
                            return self.evaluate(batched_inputs)
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    
         | 
| 329 | 
            +
                def forward_seg(self, batched_inputs):
         | 
| 330 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 331 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 332 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 333 | 
            +
                    self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    extra = {}
         | 
| 336 | 
            +
                    # mask classification target
         | 
| 337 | 
            +
                    if "instances" in batched_inputs[0]:
         | 
| 338 | 
            +
                        # the input bounding boxes have been verified to be correct.
         | 
| 339 | 
            +
                        targets = self.prepare_targets(batched_inputs, images)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                        if self.task_switch['grounding']:
         | 
| 342 | 
            +
                            grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
         | 
| 343 | 
            +
                            grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
         | 
| 344 | 
            +
                            non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
         | 
| 345 | 
            +
                            grounding_tokens[non_zero_query_mask] = 0
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                            extra['grounding_tokens'] = grounding_tokens
         | 
| 348 | 
            +
                            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
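                            # pad_sequence above fills missing positions with -1, so a token row summing to
                            # -hidden_dim is pure padding; those rows are zeroed and flagged through
                            # 'grounding_nonzero_mask' (True marks padded positions, despite the name),
                            # transposed to shape (batch, num_tokens).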
| 349 | 
            +
             | 
| 350 | 
            +
                        if self.task_switch['spatial']:
         | 
| 351 | 
            +
                            pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
         | 
| 352 | 
            +
                            neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
         | 
| 353 | 
            +
                            fp_masks = nn.utils.rnn.pad_sequence([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs], padding_value=False, batch_first=True)
         | 
| 354 | 
            +
                            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
         | 
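                            # Positive prompts come from the pre-sampled 'rand_shape' clicks, while the
                            # negative and false-positive masks start out all-False (x & False) and are only
                            # populated later, presumably by prepare_next_spaital_mask in the refinement loop.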
| 355 | 
            +
             | 
| 356 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 357 | 
            +
                    mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    # forward the spatial branch only, without gradients
         | 
| 360 | 
            +
                    if self.task_switch['spatial']:
         | 
| 361 | 
            +
                        with torch.no_grad():
         | 
| 362 | 
            +
                            # sample a random number of refinement iterations in [0, self.train_max_iter]
         | 
| 363 | 
            +
                            rand_iter_num = random.randint(0, self.train_max_iter)
         | 
| 364 | 
            +
                            for i in range(rand_iter_num):
         | 
| 365 | 
            +
                                outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
         | 
| 366 | 
            +
                                extra.update(outputs)
         | 
| 367 | 
            +
                                extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
         | 
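                    # The no-grad loop above runs a random number of interactive rounds: each round
                    # predicts spatial masks and then calls prepare_next_spaital_mask to update the
                    # prompts, so the supervised forward pass below sees a partially refined state
                    # without backpropagating through the earlier rounds.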
| 368 | 
            +
             | 
| 369 | 
            +
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
         | 
| 372 | 
            +
                             'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
         | 
| 373 | 
            +
                             'false_positive_mask': extra['false_positive_mask']}
         | 
| 374 | 
            +
                    # bipartite matching-based loss
         | 
| 375 | 
            +
                    self.criterion.losses = self.losses['seg'] # seg criterion losses
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    if self.task_switch['mask']:
         | 
| 378 | 
            +
                        losses = self.criterion(outputs, targets, extra)
         | 
| 379 | 
            +
                    else:
         | 
| 380 | 
            +
                        losses = self.criterion.forward_vlp(outputs, targets, extra)
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    del outputs
         | 
| 383 | 
            +
                    return losses
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                def evaluate(self, batched_inputs):
         | 
| 386 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 387 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 388 | 
            +
                    
         | 
| 389 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 390 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 393 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 394 | 
            +
                    outputs = self.sem_seg_head(features, target_queries=queries_grounding)
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                    mask_cls_results = outputs["pred_logits"]
         | 
| 397 | 
            +
                    mask_pred_results = outputs["pred_masks"]
         | 
| 398 | 
            +
                    box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    # upsample masks
         | 
| 401 | 
            +
                    mask_pred_results = F.interpolate(
         | 
| 402 | 
            +
                        mask_pred_results,
         | 
| 403 | 
            +
                        size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 404 | 
            +
                        mode="bilinear",
         | 
| 405 | 
            +
                        align_corners=False,
         | 
| 406 | 
            +
                    )
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    input_size = mask_pred_results.shape[-2:]
         | 
| 409 | 
            +
                    del outputs
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    processed_results = []
         | 
| 412 | 
            +
                    for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
         | 
| 413 | 
            +
                        mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
         | 
| 414 | 
            +
                    ):
         | 
| 415 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 416 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 417 | 
            +
                        processed_results.append({})
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                        if self.sem_seg_postprocess_before_inference:
         | 
| 420 | 
            +
                            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 421 | 
            +
                                mask_pred_result, image_size, height, width
         | 
| 422 | 
            +
                            )
         | 
| 423 | 
            +
                            mask_cls_result = mask_cls_result.to(mask_pred_result)
         | 
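                        # When sem_seg_postprocess_before_inference is set, the mask logits are cropped to
                        # the valid (unpadded) region and resized to the original image size here, before
                        # the per-task inference below; otherwise the semantic output is resized only after
                        # semantic_inference.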
| 424 | 
            +
             | 
| 425 | 
            +
                        # semantic segmentation inference
         | 
| 426 | 
            +
                        if self.semantic_on:
         | 
| 427 | 
            +
                            r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
         | 
| 428 | 
            +
                            if not self.sem_seg_postprocess_before_inference:
         | 
| 429 | 
            +
                                r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
         | 
| 430 | 
            +
                            processed_results[-1]["sem_seg"] = r
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                        # panoptic segmentation inference
         | 
| 433 | 
            +
                        if self.panoptic_on:
         | 
| 434 | 
            +
                            panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
         | 
| 435 | 
            +
                            processed_results[-1]["panoptic_seg"] = panoptic_r
         | 
| 436 | 
            +
                        
         | 
| 437 | 
            +
                        # instance segmentation inference
         | 
| 438 | 
            +
                        if self.instance_on:
         | 
| 439 | 
            +
                            if self.task_switch['bbox']:
         | 
| 440 | 
            +
                                box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
         | 
| 441 | 
            +
                            instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
         | 
| 442 | 
            +
                            processed_results[-1]["instances"] = instance_r
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    return processed_results
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                def evaluate_interactive(self, batched_inputs):
         | 
| 447 | 
            +
                    assert self.task_switch['spatial']
         | 
| 448 | 
            +
                    assert 'spatial_query' in batched_inputs[0]
         | 
| 449 | 
            +
                    assert len(batched_inputs) == 1, "only support batch size equal to 1"
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 452 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 453 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 454 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 457 | 
            +
                    extra = {}
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 460 | 
            +
                    mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    all_batch_shape_iou = []
         | 
| 465 | 
            +
                    pred_smask_pointer = None
         | 
| 466 | 
            +
                    prev_smask_pointer = None
         | 
| 467 | 
            +
                    pred_smask_all = None
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                    # visualization code
         | 
| 470 | 
            +
                    # v_pred_mask = []
         | 
| 471 | 
            +
                    # v_pos_mask = []
         | 
| 472 | 
            +
                    # v_neg_mask = []
         | 
| 473 | 
            +
                    # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
         | 
| 474 | 
            +
                    query_index = self.sem_seg_head.predictor.query_index
         | 
| 475 | 
            +
                    if self.interactive_mode in ['best', 'best_random']:
         | 
| 476 | 
            +
                        pos_masks = [x['spatial_query']['rand_shape'].to(self.device)[:,0] for x in batched_inputs]
         | 
| 477 | 
            +
                        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                        neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False)[:,0] for x in batched_inputs]
         | 
| 480 | 
            +
                
         | 
| 481 | 
            +
                        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 482 | 
            +
                        extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
         | 
| 483 | 
            +
                    elif self.interactive_mode == 'random':
         | 
| 484 | 
            +
                        assert False, "interactive mode not correctly implemented"
         | 
| 485 | 
            +
                        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
         | 
| 486 | 
            +
                        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
         | 
| 489 | 
            +
                        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
         | 
| 490 | 
            +
                        extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
         | 
| 491 | 
            +
                    else:
         | 
| 492 | 
            +
                        assert False, "invalid interactive mode"
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    for i in range(self.interactive_iter):
         | 
| 495 | 
            +
                        # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
         | 
| 496 | 
            +
                        # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
         | 
| 497 | 
            +
                        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
         | 
| 498 | 
            +
                        extra.update(outputs)
         | 
| 499 | 
            +
                        pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
         | 
| 500 | 
            +
                        # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                        s = image_sizes[0]
         | 
| 503 | 
            +
                        b = batched_inputs[0]
         | 
| 504 | 
            +
                        pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[0].sigmoid() > 0.5
         | 
| 505 | 
            +
                        gt_smask = b['gt_masks_orisize']
         | 
| 506 | 
            +
                        ious = get_iou(gt_smask, pred_smask_all)
         | 
| 507 | 
            +
                        all_batch_shape_iou += [ious]
         | 
| 508 | 
            +
                        if (ious > 0.9).sum() == len(ious):
         | 
| 509 | 
            +
                            all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
         | 
| 510 | 
            +
                            break
         | 
| 511 | 
            +
                        if self.interactive_mode in ['best', 'best_random']:
         | 
| 512 | 
            +
                            extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
         | 
| 513 | 
            +
                        elif self.interactive_mode == 'random':
         | 
| 514 | 
            +
                            extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
         | 
| 515 | 
            +
                        else:
         | 
| 516 | 
            +
                            assert False, "invalid interactive mode"
         | 
| 517 | 
            +
                    all_batch_shape_iou = torch.stack(all_batch_shape_iou)
         | 
| 518 | 
            +
                    processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
         | 
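                    # all_batch_shape_iou stacks to (num_iters, num_objects); results are reported per
                    # object as an IoU trajectory over the interactive iterations, and if every object
                    # exceeds IoU 0.9 early the last IoU is repeated so the trajectory keeps a fixed length.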
| 519 | 
            +
             | 
| 520 | 
            +
                    return processed_results
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                def evaluate_interactive_single(self, batched_inputs, extra={}):
         | 
| 523 | 
            +
                    assert self.task_switch['spatial']
         | 
| 524 | 
            +
                    assert 'spatial_query' in batched_inputs[0]
         | 
| 525 | 
            +
                    assert len(batched_inputs) == 1, "only support batch size equal to 1"
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 528 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 529 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 530 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 535 | 
            +
                    mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
         | 
| 538 | 
            +
                    nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
         | 
| 539 | 
            +
                    multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
         | 
| 540 | 
            +
                    mask_features = mask_features.repeat(nm,1,1,1)
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
         | 
| 543 | 
            +
                    pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                    s = image_sizes[0]
         | 
| 546 | 
            +
                    b = batched_inputs[0]
         | 
| 547 | 
            +
                    pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
         | 
| 548 | 
            +
                    pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
         | 
| 549 | 
            +
                    ious = []
         | 
| 550 | 
            +
                    if 'gt_masks_orisize' in b:
         | 
| 551 | 
            +
                        gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
         | 
| 552 | 
            +
                        ious = get_iou(gt_smask, pred_smask_ori)
         | 
| 553 | 
            +
                    processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
         | 
| 554 | 
            +
                    return processed_results
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                def evaluate_interactive_grounding(self, batched_inputs):
         | 
| 557 | 
            +
                    assert self.task_switch['spatial']
         | 
| 558 | 
            +
                    assert 'spatial_query' in batched_inputs[0]
         | 
| 559 | 
            +
                    assert len(batched_inputs) == 1, "only support batch size equal to 1"
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 562 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 563 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 564 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 567 | 
            +
                    extra = {}
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 570 | 
            +
                    mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                    image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
         | 
| 573 | 
            +
                    nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
         | 
| 574 | 
            +
                    multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
         | 
| 575 | 
            +
                    mask_features = mask_features.repeat(nm,1,1,1)
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    all_batch_shape_iou = []
         | 
| 578 | 
            +
                    pred_smask_pointer = None
         | 
| 579 | 
            +
                    prev_smask_pointer = None
         | 
| 580 | 
            +
                    pred_smask_all = None
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                    # visualization code
         | 
| 583 | 
            +
                    # v_pred_mask = []
         | 
| 584 | 
            +
                    # v_pos_mask = []
         | 
| 585 | 
            +
                    # v_neg_mask = []
         | 
| 586 | 
            +
                    # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
         | 
| 587 | 
            +
                    query_index = self.sem_seg_head.predictor.query_index
         | 
| 588 | 
            +
                    if self.interactive_mode in ['best', 'best_random']:
         | 
| 589 | 
            +
                        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
         | 
| 590 | 
            +
                        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
         | 
| 593 | 
            +
                        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 594 | 
            +
                        extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
         | 
| 595 | 
            +
                    elif self.interactive_mode == 'random':
         | 
| 596 | 
            +
                        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
         | 
| 597 | 
            +
                        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
         | 
| 600 | 
            +
                        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
         | 
| 601 | 
            +
                        extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
         | 
| 602 | 
            +
                    else:
         | 
| 603 | 
            +
                        assert False, "invalid interactive mode"
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    grd_texts = batched_inputs[0]['classes']
         | 
| 606 | 
            +
                    gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 607 | 
            +
                    token_emb = gtext['token_emb']
         | 
| 608 | 
            +
                    tokens = gtext['tokens']
         | 
| 609 | 
            +
                    query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
         | 
| 610 | 
            +
                    non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                    extra['grounding_tokens'] = query_emb
         | 
| 613 | 
            +
                    extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 614 | 
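                    # Rows padded with -1 by pad_sequence sum to a negative value and are treated as
                    # padding here; note that, unlike forward_seg, the padded rows are not zeroed before
                    # being passed to the predictor.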
            +
             | 
| 615 | 
            +
                    for i in range(self.interactive_iter):
         | 
| 616 | 
            +
                        # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
         | 
| 617 | 
            +
                        # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
         | 
| 618 | 
            +
                        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
         | 
| 619 | 
            +
                        extra.update(outputs)
         | 
| 620 | 
            +
                        pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
         | 
| 621 | 
            +
                        # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                        s = image_sizes[0]
         | 
| 624 | 
            +
                        b = batched_inputs[0]
         | 
| 625 | 
            +
                        pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
         | 
| 626 | 
            +
                        gt_smask = b['gt_masks_orisize']
         | 
| 627 | 
            +
                        ious = get_iou(gt_smask, pred_smask_all)
         | 
| 628 | 
            +
                        all_batch_shape_iou += [ious]
         | 
| 629 | 
            +
                        if (ious > 0.9).sum() == len(ious):
         | 
| 630 | 
            +
                            all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
         | 
| 631 | 
            +
                            break
         | 
| 632 | 
            +
                        if self.interactive_mode in ['best', 'best_random']:
         | 
| 633 | 
            +
                            extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
         | 
| 634 | 
            +
                        elif self.interactive_mode == 'random':
         | 
| 635 | 
            +
                            extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
         | 
| 636 | 
            +
                        else:
         | 
| 637 | 
            +
                            assert False, "invalid interactive mode"
         | 
| 638 | 
            +
                    all_batch_shape_iou = torch.stack(all_batch_shape_iou)
         | 
| 639 | 
            +
                    processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    # visualization
         | 
| 642 | 
            +
                    # VL.step()
         | 
| 643 | 
            +
                    # import cv2
         | 
| 644 | 
            +
                    # v_masks = []
         | 
| 645 | 
            +
                    # v_pos_masks = []
         | 
| 646 | 
            +
                    # v_neg_masks = []
         | 
| 647 | 
            +
                    # txt = []
         | 
| 648 | 
            +
             | 
| 649 | 
            +
                    # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
         | 
| 650 | 
            +
                    # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
         | 
| 651 | 
            +
                    # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
         | 
| 652 | 
            +
                    # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
         | 
| 653 | 
            +
                    # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
         | 
| 654 | 
            +
                    #     # dilate x,y
         | 
| 655 | 
            +
                    #     x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
         | 
| 656 | 
            +
                    #     y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
         | 
| 657 | 
            +
                    #     acc_pos_mask += x
         | 
| 658 | 
            +
                    #     acc_neg_mask += y
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                    #     v_masks += [z]
         | 
| 661 | 
            +
                    #     v_pos_masks += [acc_pos_mask.clip(0,1)]
         | 
| 662 | 
            +
                    #     v_neg_masks += [acc_neg_mask.clip(0,1)]
         | 
| 663 | 
            +
                    #     txt += ["pred_{}".format(str(iou[0].item())[0:5])]
         | 
| 664 | 
            +
             | 
| 665 | 
            +
                    # VL.add_image(img[:,:,::-1])
         | 
| 666 | 
            +
                    # VL.insert(mask_img, "gt_mask")
         | 
| 667 | 
            +
                    # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
         | 
| 668 | 
            +
                    return processed_results
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                def evaluate_referring_image(self, batched_inputs, extra={}):
         | 
| 671 | 
            +
                    assert self.task_switch['spatial']
         | 
| 672 | 
            +
                    assert len(batched_inputs) == 1, "only support batch size equal to 1"
         | 
| 673 | 
            +
                    assert self.interactive_mode == 'best'
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 676 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 677 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 678 | 
            +
                    img_bs = images.tensor.shape[0]
         | 
| 679 | 
            +
             | 
| 680 | 
            +
                    targets = targets_grounding = queries_grounding = None
         | 
| 681 | 
            +
                    features = self.backbone(images.tensor)
         | 
| 682 | 
            +
                    mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                    if 'spatial_query' in batched_inputs[0]:
         | 
| 685 | 
            +
                        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
         | 
| 686 | 
            +
                        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
         | 
| 687 | 
            +
                        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
         | 
| 688 | 
            +
                        mask_features = mask_features.repeat(nm,1,1,1)
         | 
| 689 | 
            +
             | 
| 690 | 
            +
                        query_index = self.sem_seg_head.predictor.query_index
         | 
| 691 | 
            +
                        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
         | 
| 692 | 
            +
                        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
         | 
| 695 | 
            +
                        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 696 | 
            +
                        extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
         | 
| 699 | 
            +
                    return outputs, images.tensor.shape
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                def evaluate_grounding(self, batched_inputs, mode):
         | 
| 702 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 703 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 704 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 705 | 
            +
                    assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                    extra = {}
         | 
| 708 | 
            +
                    # mask_pred_results = []
         | 
| 709 | 
            +
                    # for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 710 | 
            +
                    #     grd_texts = batch_per_image['groundings']['texts']
         | 
| 711 | 
            +
                    #     grd_masks = []
         | 
| 712 | 
            +
                    #     for anno_text in grd_texts:
         | 
| 713 | 
            +
                    #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
         | 
| 714 | 
            +
                    #         token_emb = gtext['token_emb']
         | 
| 715 | 
            +
                    #         tokens = gtext['tokens']
         | 
| 716 | 
            +
                        
         | 
| 717 | 
            +
                    #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
         | 
| 718 | 
            +
                    #         extra['grounding_tokens'] = grd_emb[:,None]
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                    #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
         | 
| 721 | 
            +
                    #         features = self.backbone(images.tensor)
         | 
| 722 | 
            +
                    #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 723 | 
            +
                            
         | 
| 724 | 
            +
                    #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
         | 
| 725 | 
            +
                    #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
         | 
| 726 | 
            +
                    #         t_emb = grd_emb[-1:]
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                    #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 729 | 
            +
                    #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 732 | 
            +
                    #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 733 | 
            +
                            
         | 
| 734 | 
            +
                    #         matched_id = out_prob.max(0)[1]
         | 
| 735 | 
            +
                    #         grd_masks += [pred_gmasks[matched_id,:,:]]
         | 
| 736 | 
            +
                    #     mask_pred_results += [torch.cat(grd_masks)]
         | 
| 737 | 
            +
             | 
| 738 | 
            +
                    # the block above is kept commented out; the code below performs multi-object inference.
         | 
| 739 | 
            +
                    mask_pred_results = []
         | 
| 740 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 741 | 
            +
                        grd_texts = batch_per_image['groundings']['texts']
         | 
| 742 | 
            +
                        if self.standard_text_for_eval:
         | 
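            +
                            # rebuild the canonical prompt '<target> in <site> <modality>' from the '+'-delimited fields at the end of the mask filename
         | 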
| 743 | 
            +
                            standard_texts = []
         | 
| 744 | 
            +
                            for grd in batch_per_image['grounding_info']:
         | 
| 745 | 
            +
                                mask_file = grd['mask_file'].split('.')[0].split('/')[-1]
         | 
| 746 | 
            +
                                target = mask_file.split('_')[-1].replace('+', ' ')
         | 
| 747 | 
            +
                                site = mask_file.split('_')[-2].replace('+', ' ')
         | 
| 748 | 
            +
                                modality = mask_file.split('_')[-3].replace('+', ' ')
         | 
| 749 | 
            +
                                standard_texts.append(f'{target} in {site} {modality}')
         | 
| 750 | 
            +
                            grd_texts = standard_texts
         | 
| 751 | 
            +
                            batch_per_image['groundings']['texts'] = standard_texts
         | 
| 752 | 
            +
             | 
| 753 | 
            +
             | 
| 754 | 
            +
                        gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 755 | 
            +
                        token_emb = gtext['token_emb']
         | 
| 756 | 
            +
                        tokens = gtext['tokens']
         | 
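            +
                        # keep only non-padding token embeddings, flattened across all prompts, as the grounding query sequence
         | 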
| 757 | 
            +
                        query_emb = token_emb[tokens['attention_mask'].bool()]
         | 
| 758 | 
            +
                        non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
         | 
| 759 | 
            +
             | 
| 760 | 
            +
                        extra['grounding_tokens'] = query_emb[:,None]
         | 
| 761 | 
            +
                        extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 762 | 
            +
             | 
| 763 | 
            +
                        features = self.backbone(images.tensor)
         | 
| 764 | 
            +
                        outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 765 | 
            +
             | 
| 766 | 
            +
                        pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 767 | 
            +
                        v_emb = outputs['pred_gtexts'][idx]
         | 
| 768 | 
            +
                        t_emb = gtext['class_emb']
         | 
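            +
                        # L2-normalize both embeddings, then match each text prompt to the grounding query with the highest temperature-scaled similarity
         | 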
| 769 | 
            +
             | 
| 770 | 
            +
                        t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 771 | 
            +
                        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 772 | 
            +
             | 
| 773 | 
            +
                        temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 774 | 
            +
                        out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 775 | 
            +
                        
         | 
| 776 | 
            +
                        matched_id = out_prob.max(0)[1]
         | 
| 777 | 
            +
                        mask_pred_results += [pred_gmasks[matched_id,:,:]]
         | 
| 778 | 
            +
             | 
| 779 | 
            +
                    for i in range(len(mask_pred_results)):
         | 
| 780 | 
            +
                        # upsample masks
         | 
| 781 | 
            +
                        mask_pred_results[i] = F.interpolate(
         | 
| 782 | 
            +
                            mask_pred_results[i][None,],
         | 
| 783 | 
            +
                            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 784 | 
            +
                            mode="bilinear",
         | 
| 785 | 
            +
                            align_corners=False,
         | 
| 786 | 
            +
                        )[0]
         | 
| 787 | 
            +
             | 
| 788 | 
            +
                    processed_results = []
         | 
| 789 | 
            +
                    for mask_pred_result, input_per_image, image_size in zip(
         | 
| 790 | 
            +
                        mask_pred_results, batched_inputs, images.image_sizes
         | 
| 791 | 
            +
                    ):
         | 
| 792 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 793 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 794 | 
            +
                        processed_results.append({})
         | 
| 795 | 
            +
             | 
| 796 | 
            +
                        mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 797 | 
            +
                            mask_pred_result, image_size, height, width
         | 
| 798 | 
            +
                        )
         | 
| 799 | 
            +
                        processed_results[-1]['grounding_mask'] = mask_pred_result
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                        # compute bbox
         | 
| 802 | 
            +
                        # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
         | 
| 803 | 
            +
                        # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
         | 
| 804 | 
            +
                        # processed_results[-1]['grounding_box'] = bbox
         | 
| 805 | 
            +
             | 
| 806 | 
            +
                    return processed_results
         | 
| 807 | 
            +
             | 
| 808 | 
            +
                def evaluate_grounding_sptial(self, batched_inputs, mode):
         | 
| 809 | 
            +
                    images = [x["image"].to(self.device) for x in batched_inputs]
         | 
| 810 | 
            +
                    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         | 
| 811 | 
            +
                    images = ImageList.from_tensors(images, self.size_divisibility)
         | 
| 812 | 
            +
                    assert len(images.tensor) == 1, "grounding evaluation only supports a single batch size for now"
         | 
| 813 | 
            +
             | 
| 814 | 
            +
                    extra = {}
         | 
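            +
                    # pad the user-provided spatial queries to the model input size and slightly dilate each positive click/scribble mask; negative masks start empty
         | 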
| 815 | 
            +
                    dilation = 3
         | 
| 816 | 
            +
                    pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
         | 
| 817 | 
            +
                    pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
         | 
| 818 | 
            +
                    pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
         | 
| 819 | 
            +
             | 
| 820 | 
            +
                    neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
         | 
| 821 | 
            +
                    neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                    mask_pred_results = []
         | 
| 824 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 825 | 
            +
                        grd_texts = batch_per_image['groundings']['texts']
         | 
| 826 | 
            +
                        grd_masks = []
         | 
| 827 | 
            +
                        for idx2, anno_text in enumerate(grd_texts):
         | 
| 828 | 
            +
                            extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
         | 
| 831 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 832 | 
            +
                            tokens = gtext['tokens']
         | 
| 833 | 
            +
                        
         | 
| 834 | 
            +
                            grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
         | 
| 835 | 
            +
                            non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
         | 
| 836 | 
            +
                            extra['grounding_tokens'] = grd_emb[:,None]
         | 
| 837 | 
            +
                            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 838 | 
            +
             | 
| 839 | 
            +
                            assert len(images.tensor) == 1, "grounding evaluation only supports a single batch size for now"
         | 
| 840 | 
            +
                            features = self.backbone(images.tensor)
         | 
| 841 | 
            +
                            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 842 | 
            +
                            
         | 
| 843 | 
            +
                            pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 844 | 
            +
                            v_emb = outputs['pred_gtexts'][idx]
         | 
| 845 | 
            +
                            t_emb = gtext['class_emb']
         | 
| 846 | 
            +
             | 
| 847 | 
            +
                            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 848 | 
            +
                            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 849 | 
            +
             | 
| 850 | 
            +
                            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 851 | 
            +
                            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 852 | 
            +
                            
         | 
| 853 | 
            +
                            matched_id = out_prob.max(0)[1]
         | 
| 854 | 
            +
                            grd_masks += [pred_gmasks[matched_id,:,:]]
         | 
| 855 | 
            +
                            # grd_masks += [outputs['prev_mask'][0]]
         | 
| 856 | 
            +
             | 
| 857 | 
            +
                        mask_pred_results += [torch.cat(grd_masks)]
         | 
| 858 | 
            +
             | 
| 859 | 
            +
                    # multi-object inference variant (all prompts in one forward pass) is kept commented out below.
         | 
| 860 | 
            +
                    # mask_pred_results = []
         | 
| 861 | 
            +
                    # for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 862 | 
            +
                    #     grd_texts = batch_per_image['groundings']['texts']
         | 
| 863 | 
            +
                    #     grd_texts = [x[0] for x in grd_texts]
         | 
| 864 | 
            +
             | 
| 865 | 
            +
                    #     gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 866 | 
            +
                    #     token_emb = gtext['token_emb']
         | 
| 867 | 
            +
                    #     tokens = gtext['tokens']
         | 
| 868 | 
            +
                    #     query_emb = token_emb[tokens['attention_mask'].bool()]
         | 
| 869 | 
            +
                    #     non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
         | 
| 870 | 
            +
             | 
| 871 | 
            +
                    #     extra['grounding_tokens'] = query_emb[:,None]
         | 
| 872 | 
            +
                    #     extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    #     features = self.backbone(images.tensor)
         | 
| 875 | 
            +
                    #     outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
         | 
| 876 | 
            +
             | 
| 877 | 
            +
                    #     pred_gmasks = outputs['pred_gmasks'][idx]
         | 
| 878 | 
            +
                    #     v_emb = outputs['pred_gtexts'][idx]
         | 
| 879 | 
            +
                    #     t_emb = gtext['class_emb']
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                    #     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 882 | 
            +
                    #     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                    #     temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
         | 
| 885 | 
            +
                    #     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
         | 
| 886 | 
            +
                        
         | 
| 887 | 
            +
                    #     matched_id = out_prob.max(0)[1]
         | 
| 888 | 
            +
                    #     mask_pred_results += [pred_gmasks[matched_id,:,:]]
         | 
| 889 | 
            +
             | 
| 890 | 
            +
                    for i in range(len(mask_pred_results)):
         | 
| 891 | 
            +
                        # upsample masks
         | 
| 892 | 
            +
                        mask_pred_results[i] = F.interpolate(
         | 
| 893 | 
            +
                            mask_pred_results[i][None,],
         | 
| 894 | 
            +
                            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
         | 
| 895 | 
            +
                            mode="bilinear",
         | 
| 896 | 
            +
                            align_corners=False,
         | 
| 897 | 
            +
                        )[0]
         | 
| 898 | 
            +
             | 
| 899 | 
            +
                    processed_results = []
         | 
| 900 | 
            +
                    for mask_pred_result, input_per_image, image_size in zip(
         | 
| 901 | 
            +
                        mask_pred_results, batched_inputs, images.image_sizes
         | 
| 902 | 
            +
                    ):
         | 
| 903 | 
            +
                        height = input_per_image.get("height", image_size[0])
         | 
| 904 | 
            +
                        width = input_per_image.get("width", image_size[1])
         | 
| 905 | 
            +
                        processed_results.append({})
         | 
| 906 | 
            +
             | 
| 907 | 
            +
                        mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
         | 
| 908 | 
            +
                            mask_pred_result, image_size, height, width
         | 
| 909 | 
            +
                        )
         | 
| 910 | 
            +
                        processed_results[-1]['grounding_mask'] = mask_pred_result
         | 
| 911 | 
            +
             | 
| 912 | 
            +
                    return processed_results
         | 
| 913 | 
            +
             | 
| 914 | 
            +
                def prepare_targets(self, batched_inputs, images):
         | 
| 915 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 916 | 
            +
                    new_targets = []
         | 
| 917 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):            
         | 
| 918 | 
            +
                        target_dict = {}
         | 
| 919 | 
            +
                        if self.task_switch['mask']:
         | 
| 920 | 
            +
                            targets_per_image = batch_per_image['instances'].to(self.device)
         | 
| 921 | 
            +
                            # pad gt
         | 
| 922 | 
            +
                            gt_masks = targets_per_image.gt_masks.tensor
         | 
| 923 | 
            +
                            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 924 | 
            +
                            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 925 | 
            +
             | 
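            +
                            # normalize gt boxes by the padded image size and convert from xyxy to cxcywh
         | 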
| 926 | 
            +
                            gt_boxes = targets_per_image.gt_boxes.tensor
         | 
| 927 | 
            +
                            ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
         | 
| 928 | 
            +
                            gt_boxes = gt_boxes / ratio
         | 
| 929 | 
            +
                            xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
         | 
| 930 | 
            +
                            gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
         | 
| 931 | 
            +
             | 
| 932 | 
            +
                            target_dict.update({
         | 
| 933 | 
            +
                                    "labels": targets_per_image.gt_classes,
         | 
| 934 | 
            +
                                    "is_things": targets_per_image.is_things,
         | 
| 935 | 
            +
                                    "masks": padded_masks,
         | 
| 936 | 
            +
                                    "boxes": gt_boxes,
         | 
| 937 | 
            +
                                    })
         | 
| 938 | 
            +
             | 
| 939 | 
            +
                        if self.task_switch['spatial']:
         | 
| 940 | 
            +
                            # prepare targets for spatial query
         | 
| 941 | 
            +
                            target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
         | 
| 942 | 
            +
             | 
| 943 | 
            +
                        if self.task_switch['grounding']:
         | 
| 944 | 
            +
                            grd_masks = batch_per_image['groundings']['masks']
         | 
| 945 | 
            +
                            grd_texts = batch_per_image['groundings']['texts']
         | 
| 946 | 
            +
                            grd_hash = batch_per_image['groundings']['hash']
         | 
| 947 | 
            +
                            grd_task = batch_per_image['groundings']['mode']
         | 
| 948 | 
            +
                            
         | 
| 949 | 
            +
                            if len(grd_masks) == 0:
         | 
| 950 | 
            +
                                padded_masks = None
         | 
| 951 | 
            +
                            else:
         | 
| 952 | 
            +
                                padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
         | 
| 953 | 
            +
                                padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
         | 
| 954 | 
            +
             | 
| 955 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 956 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 957 | 
            +
                            tokens = gtext['tokens']
         | 
| 958 | 
            +
                            
         | 
| 959 | 
            +
                            unique_hash_id = np.unique(grd_hash, return_index=True)[1]
         | 
| 960 | 
            +
                            selected_mask = np.zeros(len(grd_hash)).astype(bool)
         | 
| 961 | 
            +
                            selected_mask[unique_hash_id] = True
         | 
| 962 | 
            +
             | 
| 963 | 
            +
                            selected_token_emb = token_emb[selected_mask]
         | 
| 964 | 
            +
                            selected_attn_mask = tokens['attention_mask'][selected_mask]
         | 
| 965 | 
            +
                            query_emb = selected_token_emb[selected_attn_mask.bool()]
         | 
| 966 | 
            +
                            
         | 
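            +
                            # the last valid token of each prompt serves as its sentence-level (class) embedding
         | 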
| 967 | 
            +
                            class_idx = tokens['attention_mask'].sum(dim=-1) - 1
         | 
| 968 | 
            +
                            class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
         | 
| 969 | 
            +
                            class_emb = token_emb[class_idx]
         | 
| 970 | 
            +
                            
         | 
| 971 | 
            +
                            target_dict['grounding_masks'] = padded_masks
         | 
| 972 | 
            +
                            target_dict['grounding_query_embs'] = query_emb
         | 
| 973 | 
            +
                            target_dict['grounding_class_embs'] = class_emb
         | 
| 974 | 
            +
                            target_dict['grounding_hash'] = grd_hash
         | 
| 975 | 
            +
                            target_dict['grounding_task'] = grd_task
         | 
| 976 | 
            +
             | 
| 977 | 
            +
                        new_targets.append(target_dict)
         | 
| 978 | 
            +
                    return new_targets
         | 
| 979 | 
            +
             | 
| 980 | 
            +
                def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
         | 
| 981 | 
            +
                    gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
         | 
| 982 | 
            +
                    gt_masks = Spatial_ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
         | 
| 983 | 
            +
             | 
| 984 | 
            +
                    pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
         | 
| 985 | 
            +
                    prev_masks = nn.utils.rnn.pad_sequence(outputs['spatial_query_pos_mask'], padding_value=False, batch_first=True) | \
         | 
| 986 | 
            +
                                    nn.utils.rnn.pad_sequence(outputs['spatial_query_neg_mask'], padding_value=False, batch_first=True)
         | 
| 987 | 
            +
             | 
| 988 | 
            +
                    fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
         | 
| 989 | 
            +
                    fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
         | 
| 990 | 
            +
             | 
| 991 | 
            +
                    # compute iou between gt and pred
         | 
| 992 | 
            +
                    iou = (gt_masks & pred_masks).sum(list(range(2,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(2,len(fn.shape)))) + 1e-8)
         | 
| 993 | 
            +
                    fn_sum = fn.sum(dim=list(range(2,len(fn.shape))))
         | 
| 994 | 
            +
                    fp_sum = fp.sum(dim=list(range(2,len(fp.shape))))
         | 
| 995 | 
            +
             | 
| 996 | 
            +
                    is_positive = fn_sum > fp_sum
         | 
| 997 | 
            +
                    select_mask = torch.zeros_like(fn)
         | 
| 998 | 
            +
                    select_mask[is_positive] = fn[is_positive]
         | 
| 999 | 
            +
                    select_mask[~is_positive] = fp[~is_positive]
         | 
| 1000 | 
            +
                    # is_positive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
         | 
| 1001 | 
            +
             | 
| 1002 | 
            +
                    # conv implementation
         | 
| 1003 | 
            +
                    bs,ns,h,w = select_mask.shape
         | 
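            +
                    # distance transform of the selected error region (false negatives or false positives, whichever dominates): the next simulated click is placed at the interior point farthest from the region boundary
         | 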
| 1004 | 
            +
                    mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(bs*ns,-1)
         | 
| 1005 | 
            +
                    if mode == 'best':
         | 
| 1006 | 
            +
                        max_xy_idx = torch.stack([torch.arange(bs*ns), mask_dt.max(dim=-1)[1].cpu()]).tolist()
         | 
| 1007 | 
            +
                    elif mode == 'best_random':
         | 
| 1008 | 
            +
                        max_xy_idx = torch.stack([torch.arange(bs*ns), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
         | 
| 1009 | 
            +
                    next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
         | 
| 1010 | 
            +
                    next_mask = next_mask.view(bs*ns,-1)
         | 
| 1011 | 
            +
                    next_mask[max_xy_idx] = True
         | 
| 1012 | 
            +
                    next_mask = next_mask.reshape((bs*ns,1,h,w)).float()
         | 
| 1013 | 
            +
                    dilation = 3
         | 
| 1014 | 
            +
                    next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2).reshape(bs,ns,h,w) > 0
         | 
| 1015 | 
            +
             | 
| 1016 | 
            +
                    # determine whether next mask is zero
         | 
| 1017 | 
            +
                    keep = (iou < 0.925)
         | 
| 1018 | 
            +
                    next_mask = next_mask & keep.view(bs,ns,1,1)
         | 
| 1019 | 
            +
             | 
| 1020 | 
            +
                    pos_mask = []
         | 
| 1021 | 
            +
                    neg_mask = []
         | 
| 1022 | 
            +
                    for idx, ip in enumerate(is_positive):
         | 
| 1023 | 
            +
                        mask_len = len(outputs['spatial_query_pos_mask'][idx])
         | 
| 1024 | 
            +
                        pos_mask += [outputs['spatial_query_pos_mask'][idx] | (next_mask[idx][:mask_len] & ip[:mask_len,None,None])]
         | 
| 1025 | 
            +
                        neg_mask += [outputs['spatial_query_neg_mask'][idx] | (next_mask[idx][:mask_len] & (~ip[:mask_len,None,None]))]
         | 
| 1026 | 
            +
             | 
| 1027 | 
            +
                    if 'false_positive_mask' in outputs:
         | 
| 1028 | 
            +
                        fp = outputs['false_positive_mask'] | fp
         | 
| 1029 | 
            +
                    return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
         | 
| 1030 | 
            +
             | 
| 1031 | 
            +
                def semantic_inference(self, mask_cls, mask_pred):
         | 
| 1032 | 
            +
                    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
         | 
| 1033 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 1034 | 
            +
                    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
         | 
| 1035 | 
            +
                    return semseg
         | 
| 1036 | 
            +
             | 
| 1037 | 
            +
                def panoptic_inference(self, mask_cls, mask_pred):
         | 
| 1038 | 
            +
                    scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
         | 
| 1039 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                    keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
         | 
| 1042 | 
            +
                    cur_scores = scores[keep]
         | 
| 1043 | 
            +
                    cur_classes = labels[keep]
         | 
| 1044 | 
            +
                    cur_masks = mask_pred[keep]
         | 
| 1045 | 
            +
                    cur_mask_cls = mask_cls[keep]
         | 
| 1046 | 
            +
                    cur_mask_cls = cur_mask_cls[:, :-1]
         | 
| 1047 | 
            +
             | 
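            +
                    # weight each mask by its class score so the per-pixel argmax below assigns pixels to the most confident query
         | 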
| 1048 | 
            +
                    cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
         | 
| 1049 | 
            +
             | 
| 1050 | 
            +
                    h, w = cur_masks.shape[-2:]
         | 
| 1051 | 
            +
                    panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
         | 
| 1052 | 
            +
                    segments_info = []
         | 
| 1053 | 
            +
             | 
| 1054 | 
            +
                    current_segment_id = 0
         | 
| 1055 | 
            +
             | 
| 1056 | 
            +
                    if cur_masks.shape[0] == 0:
         | 
| 1057 | 
            +
                        # We didn't detect any mask :(
         | 
| 1058 | 
            +
                        return panoptic_seg, segments_info
         | 
| 1059 | 
            +
                    else:
         | 
| 1060 | 
            +
                        # take argmax
         | 
| 1061 | 
            +
                        cur_mask_ids = cur_prob_masks.argmax(0)
         | 
| 1062 | 
            +
                        stuff_memory_list = {}
         | 
| 1063 | 
            +
                        for k in range(cur_classes.shape[0]):
         | 
| 1064 | 
            +
                            pred_class = cur_classes[k].item()
         | 
| 1065 | 
            +
                            isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 1066 | 
            +
                            mask_area = (cur_mask_ids == k).sum().item()
         | 
| 1067 | 
            +
                            original_area = (cur_masks[k] >= 0.5).sum().item()
         | 
| 1068 | 
            +
                            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
         | 
| 1069 | 
            +
             | 
| 1070 | 
            +
                            if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
         | 
| 1071 | 
            +
                                if mask_area / original_area < self.overlap_threshold:
         | 
| 1072 | 
            +
                                    continue
         | 
| 1073 | 
            +
             | 
| 1074 | 
            +
                                # merge stuff regions
         | 
| 1075 | 
            +
                                if not isthing:
         | 
| 1076 | 
            +
                                    if int(pred_class) in stuff_memory_list.keys():
         | 
| 1077 | 
            +
                                        panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
         | 
| 1078 | 
            +
                                        continue
         | 
| 1079 | 
            +
                                    else:
         | 
| 1080 | 
            +
                                        stuff_memory_list[int(pred_class)] = current_segment_id + 1
         | 
| 1081 | 
            +
             | 
| 1082 | 
            +
                                current_segment_id += 1
         | 
| 1083 | 
            +
                                panoptic_seg[mask] = current_segment_id
         | 
| 1084 | 
            +
             | 
| 1085 | 
            +
                                segments_info.append(
         | 
| 1086 | 
            +
                                    {
         | 
| 1087 | 
            +
                                        "id": current_segment_id,
         | 
| 1088 | 
            +
                                        "isthing": bool(isthing),
         | 
| 1089 | 
            +
                                        "category_id": int(pred_class),
         | 
| 1090 | 
            +
                                    }
         | 
| 1091 | 
            +
                                )
         | 
| 1092 | 
            +
             | 
| 1093 | 
            +
                        return panoptic_seg, segments_info
         | 
| 1094 | 
            +
             | 
| 1095 | 
            +
                def instance_inference(self, mask_cls, mask_pred, box_pred):
         | 
| 1096 | 
            +
                    # mask_pred is already processed to have the same shape as original input
         | 
| 1097 | 
            +
                    image_size = mask_pred.shape[-2:]
         | 
| 1098 | 
            +
             | 
| 1099 | 
            +
                    # [Q, K]
         | 
| 1100 | 
            +
                    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
         | 
| 1101 | 
            +
                    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
         | 
| 1102 | 
            +
                    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
         | 
| 1103 | 
            +
                    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
         | 
| 1104 | 
            +
             | 
| 1105 | 
            +
                    labels_per_image = labels[topk_indices]
         | 
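            +
                    # recover the query index from the flattened (query, class) index
         | 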
| 1106 | 
            +
                    topk_indices = (topk_indices // self.sem_seg_head.num_classes)
         | 
| 1107 | 
            +
                    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
         | 
| 1108 | 
            +
                    mask_pred = mask_pred[topk_indices]
         | 
| 1109 | 
            +
                    if box_pred is not None:
         | 
| 1110 | 
            +
                        box_pred = box_pred[topk_indices]
         | 
| 1111 | 
            +
             | 
| 1112 | 
            +
                    # if this is panoptic segmentation, we only keep the "thing" classes
         | 
| 1113 | 
            +
                    if self.panoptic_on:
         | 
| 1114 | 
            +
                        keep = torch.zeros_like(scores_per_image).bool()
         | 
| 1115 | 
            +
                        for i, lab in enumerate(labels_per_image):
         | 
| 1116 | 
            +
                            keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
                        scores_per_image = scores_per_image[keep]
         | 
| 1119 | 
            +
                        labels_per_image = labels_per_image[keep]
         | 
| 1120 | 
            +
                        mask_pred = mask_pred[keep]
         | 
| 1121 | 
            +
             | 
| 1122 | 
            +
                        if box_pred is not None:
         | 
| 1123 | 
            +
                            box_pred = box_pred[keep]
         | 
| 1124 | 
            +
             | 
| 1125 | 
            +
                    result = Instances(image_size)
         | 
| 1126 | 
            +
                    # mask (before sigmoid)
         | 
| 1127 | 
            +
                    result.pred_masks = (mask_pred > 0).float()
         | 
| 1128 | 
            +
                    # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 1129 | 
            +
                    # Uncomment the following to get boxes from masks (this is slow)
         | 
| 1130 | 
            +
             | 
| 1131 | 
            +
                    if box_pred is not None:
         | 
| 1132 | 
            +
                        result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
         | 
| 1133 | 
            +
                    else:
         | 
| 1134 | 
            +
                        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 1135 | 
            +
             | 
| 1136 | 
            +
                    # calculate average mask prob
         | 
| 1137 | 
            +
                    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
         | 
| 1138 | 
            +
                    result.scores = scores_per_image * mask_scores_per_image
         | 
| 1139 | 
            +
                    result.pred_classes = labels_per_image
         | 
| 1140 | 
            +
             | 
| 1141 | 
            +
                    return result
         | 
| 1142 | 
            +
             | 
| 1143 | 
            +
                def prepare_targets4query(self, targets, images, topk=5):
         | 
| 1144 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 1145 | 
            +
                    new_targets = []
         | 
| 1146 | 
            +
                    new_queries = []
         | 
| 1147 | 
            +
                    for targets_per_image in targets:
         | 
| 1148 | 
            +
                        # we randomly sample maximally topk concepts
         | 
| 1149 | 
            +
                        unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
         | 
| 1150 | 
            +
                        selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
         | 
| 1151 | 
            +
                        new_targets_per_image = []
         | 
| 1152 | 
            +
                        new_queries_per_image = []
         | 
| 1153 | 
            +
                        for clss in selected_target_classes:
         | 
| 1154 | 
            +
                            indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
         | 
| 1155 | 
            +
                            # pad gt
         | 
| 1156 | 
            +
                            gt_masks = targets_per_image.gt_masks[indices]
         | 
| 1157 | 
            +
                            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 1158 | 
            +
                            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 1159 | 
            +
             | 
| 1160 | 
            +
                            # convert class into concept name and then token seq
         | 
| 1161 | 
            +
                            self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([BIOMED_CLASSES[clss]], name='grounding')
         | 
| 1162 | 
            +
                            query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
         | 
| 1163 | 
            +
             | 
| 1164 | 
            +
                            new_targets.append(
         | 
| 1165 | 
            +
                                {
         | 
| 1166 | 
            +
                                    "labels": targets_per_image.gt_classes[indices],
         | 
| 1167 | 
            +
                                    "masks": padded_masks,
         | 
| 1168 | 
            +
                                }
         | 
| 1169 | 
            +
                            )
         | 
| 1170 | 
            +
                            new_queries_per_image.append(query)
         | 
| 1171 | 
            +
                        new_queries.append(new_queries_per_image)
         | 
| 1172 | 
            +
             | 
| 1173 | 
            +
                    return new_targets, new_queries
         | 
| 1174 | 
            +
             | 
| 1175 | 
            +
             | 
| 1176 | 
            +
             | 
| 1177 | 
            +
            @register_model
         | 
| 1178 | 
            +
            def get_seem_model(cfg, **kwargs):
         | 
| 1179 | 
            +
                return GeneralizedSEEM(cfg)
         | 
    	
        modeling/architectures/xdecoder_model.py
    ADDED
    
    | @@ -0,0 +1,937 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
         | 
| 3 | 
            +
            # Copyright (c) 2022 Microsoft
         | 
| 4 | 
            +
            # Licensed under The MIT License [see LICENSE for details]
         | 
| 5 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu), Ziyi Dou, Jianwei Yang
         | 
| 6 | 
            +
            # --------------------------------------------------------
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from typing import Tuple
         | 
| 9 | 
            +
            import random
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            from torch import nn
         | 
| 13 | 
            +
            from torch.nn import functional as F
         | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from timm.models.layers import trunc_normal_
         | 
| 17 | 
            +
            from nltk.stem.lancaster import LancasterStemmer
         | 
| 18 | 
            +
            from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
         | 
| 19 | 
            +
            from detectron2.utils.memory import retry_if_cuda_oom
         | 
| 20 | 
            +
            from detectron2.data import MetadataCatalog
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from .build import register_model
         | 
| 23 | 
            +
            from ..utils import configurable, get_class_names
         | 
| 24 | 
            +
            from ..vision.backbone import build_backbone, Backbone
         | 
| 25 | 
            +
            from ..body import build_xdecoder_head
         | 
| 26 | 
            +
            from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
         | 
| 27 | 
            +
            from ..language import build_language_encoder
         | 
| 28 | 
            +
            from ..language.loss import vl_similarity, image_text_contrastive_loss_queue
         | 
| 29 | 
            +
            from utilities.prompt_engineering import prompt_engineering
         | 
| 30 | 
            +
            from utilities.constants import COCO_PANOPTIC_CLASSES
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            st = LancasterStemmer()
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            class GeneralizedXdecoder(nn.Module):
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                @configurable
         | 
| 38 | 
            +
                def __init__(
         | 
| 39 | 
            +
                    self,
         | 
| 40 | 
            +
                    *,
         | 
| 41 | 
            +
                    backbone: Backbone,
         | 
| 42 | 
            +
                    sem_seg_head: nn.Module,
         | 
| 43 | 
            +
                    criterion: nn.Module,
         | 
| 44 | 
            +
                    losses: dict,
         | 
| 45 | 
            +
                    num_queries: int,
         | 
| 46 | 
            +
                    object_mask_threshold: float,
         | 
| 47 | 
            +
                    overlap_threshold: float,
         | 
| 48 | 
            +
                    metadata,
         | 
| 49 | 
            +
                    task_switch: dict,
         | 
| 50 | 
            +
                    phrase_prob: float,
         | 
| 51 | 
            +
                    size_divisibility: int,
         | 
| 52 | 
            +
                    sem_seg_postprocess_before_inference: bool,
         | 
| 53 | 
            +
                    pixel_mean: Tuple[float],
         | 
| 54 | 
            +
                    pixel_std: Tuple[float],
         | 
| 55 | 
            +
                    # inference
         | 
| 56 | 
            +
                    semantic_on: bool,
         | 
| 57 | 
            +
                    panoptic_on: bool,
         | 
| 58 | 
            +
                    instance_on: bool,
         | 
| 59 | 
            +
                    test_topk_per_image: int,
         | 
| 60 | 
            +
                    train_dataset_name: str,
         | 
| 61 | 
            +
                    retrieval_emsemble: bool,
         | 
| 62 | 
            +
                    backbone_dim: int,
         | 
| 63 | 
            +
                    dim_proj: int,
         | 
| 64 | 
            +
                ):
         | 
| 65 | 
            +
                    """
         | 
| 66 | 
            +
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such a requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to the original input size before semantic segmentation inference or after.
                For high-resolution datasets like Mapillary, resizing predictions before
                inference will cause OOM errors.
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std used to normalize the input image
            semantic_on: bool, whether to output semantic segmentation prediction
            instance_on: bool, whether to output instance segmentation prediction
            panoptic_on: bool, whether to output panoptic segmentation prediction
            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.losses = losses
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on

        # caption argument
        self.task_switch = task_switch
        self.phrase_prob = phrase_prob

        self.test_topk_per_image = test_topk_per_image
        self.train_class_names = get_class_names(train_dataset_name)

        self.retrieval_emsemble = retrieval_emsemble
        # backbone itc loss
        if task_switch['retrieval'] and retrieval_emsemble:
            self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
            trunc_normal_(self.backbone_proj, std=.02)

        if not self.semantic_on:
            assert self.sem_seg_postprocess_before_inference

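    # NOTE: from_config() below maps the plain config dict (cfg['MODEL']['DECODER'] etc.)
    # onto the keyword arguments of __init__; the keys of the returned dict are expected
    # to line up one-to-one with the constructor signature above.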
    @classmethod
    def from_config(cls, cfg):
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        # Loss parameters:
        deep_supervision = dec_cfg['DEEP_SUPERVISION']
        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']

        # loss weights, switcher for task, and top layers to compute loss
        loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
                        'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
                        'caption': dec_cfg['CAPTION_WEIGHT'],
                        'captioning': dec_cfg['CAPTIONING_WEIGHT'],
                        'retrieval': {'decoder': dec_cfg['RETRIEVAL_WEIGHT'], 'backbone': dec_cfg['BACKBONER_WEIGHT']},
                        'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']}}

        task_switch = {'bbox': dec_cfg.get('DETECTION', False),
                       'mask': dec_cfg.get('MASK', True),
                       'caption': dec_cfg['CAPTION'].get('ENABLED', False),
                       'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
                       'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
                       'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}

        top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
                        'caption': dec_cfg.get('TOP_CAPTION_LAYERS', 10),
                        'captioning': dec_cfg.get('TOP_CAPTIONING_LAYERS', 10),
                        'retrieval': dec_cfg.get('TOP_RETRIEVAL_LAYERS', 10),
                        'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10)}

        # build model
        extra = {'task_switch': task_switch}
        backbone = build_backbone(cfg)
        lang_encoder = build_language_encoder(cfg)
        sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)

        # building criterion
        matcher = HungarianMatcher(
            cost_class=loss_weights['mask']['ce'],
            cost_mask=loss_weights['mask']['bce'],
            cost_dice=loss_weights['mask']['dice'],
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
        )

        # init weight dict and criterion loss functions.
        losses = {'seg': [], 'vlp': []}
        if task_switch['mask']:
            losses['seg'] += ["labels", "masks"]
        if task_switch['caption']:
            losses['seg'] += ["captions"]
        if task_switch['grounding']:
            losses['seg'] += ["groundings"]
        if task_switch['captioning']:
            losses['vlp'] += ["captionings"]
        if task_switch['retrieval']:
            losses['vlp'] += ["retrievals"]

        weight_dict = {}
        for key, turn_on in task_switch.items():
            if turn_on:
                if isinstance(loss_weights[key], dict):
                    # HACK: bbox should be supported here in the future
                    for key_, weight in loss_weights[key].items():
                        weight_dict["loss_{}_{}_0".format(key, key_)] = weight  # NOTE: hard-coded for segmentation, which has multiple losses
                else:
                    weight_dict["loss_{}_0".format(key)] = loss_weights[key]

        # generate the full weight dict and remove layers that are not computed.
        if deep_supervision:
            dec_layers = dec_cfg['DEC_LAYERS']
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                for k, v in weight_dict.items():
                    if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
                        continue
                    aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
            weight_dict.update(aux_weight_dict)

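        # Illustrative sketch of the keys built above (actual values depend on the config):
        # with mask and grounding enabled and deep supervision over DEC_LAYERS layers,
        # weight_dict holds entries such as 'loss_mask_ce_0', 'loss_mask_dice_0',
        # 'loss_mask_bce_0', 'loss_grounding_ce_0', ... plus per-layer copies
        # 'loss_mask_ce_1' ... 'loss_mask_ce_{DEC_LAYERS-1}', capped per task by top_x_layers.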
        grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
        # generate criterion for the loss function.
        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            top_x_layers=top_x_layers,
            eos_coef=no_object_weight,
            losses=[],
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
            grounding_weight=grd_weight,
        )

        # extra logic
        train_dataset_name = cfg['DATASETS']['TRAIN'][0]  # HACK: assumes a single training set.
        phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "losses": losses,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                or dec_cfg['TEST']['PANOPTIC_ON']
                or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            "task_switch": task_switch,
            "phrase_prob": phrase_prob,
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
            "train_dataset_name": train_dataset_name,
            "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
            "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
            "dim_proj": cfg['MODEL']['DIM_PROJ'],
        }

    @property
    def device(self):
        return self.pixel_mean.device

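    # NOTE: during training, batched_inputs is expected to be a dict of per-task
    # sub-batches, e.g. {'coco': [...], 'vlp': [...]} (consumed by forward_seg and
    # forward_vlp below); at inference time it is a plain list of image dicts and
    # `mode` selects the evaluation branch ('retrieval', 'captioning',
    # 'classification', 'grounding_refcoco', or the default segmentation path).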
    def forward(self, batched_inputs, mode=None):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represents the panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        if self.training:
            losses = {}
            if self.task_switch['mask']:
                losses_seg = self.forward_seg(batched_inputs['coco'])
                losses.update(losses_seg)
            if self.task_switch['retrieval'] or self.task_switch['captioning']:
                losses_vlp = self.forward_vlp(batched_inputs['vlp'])
                losses.update(losses_vlp)
            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:  # remove this loss if not specified in `weight_dict`
                    losses.pop(k)
            return losses
        else:
            if mode == 'retrieval':
                return self.evaluate_retrieval(batched_inputs)
            elif mode == 'captioning':
                return self.evaluate_captioning(batched_inputs)
            elif mode == 'classification':
                return self.evaluate_classification(batched_inputs)
            elif mode == 'grounding_refcoco':
                return self.evaluate_grounding(batched_inputs, mode)
            else:
                return self.evaluate(batched_inputs)

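    # NOTE: the decoder returns object queries and (optionally) grounding queries in a
    # single tensor; forward_seg below keeps [:num_queries-1] as the object queries and
    # [num_queries:2*num_queries-1] as the grounding branch. For example, assuming
    # NUM_OBJECT_QUERIES=101 (illustrative value only), that is indices 0-99 vs. 101-200.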
    def forward_seg(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)

        extra = {}
        # mask classification target
        if "instances" in batched_inputs[0]:
            # input bounding box is checked to be correct.
            targets = self.prepare_targets(batched_inputs, images)

            if self.task_switch['grounding']:
                grounding_tokens = [x['grounding_query_embs'] for x in targets]  # need to pad for more than one grounding token
                grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens)
                extra['grounding_tokens'] = grounding_tokens

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, extra=extra)

        _outputs = {}
        for key, value in outputs.items():
            if key == 'pred_logits':
                _outputs[key] = value[:,:self.num_queries-1]
            elif key == 'pred_masks':
                _outputs[key] = value[:,:self.num_queries-1]
                if self.task_switch['grounding']:
                    _outputs['pred_gmasks'] = value[:,self.num_queries:2*self.num_queries-1]
            elif key == 'pred_captions':
                _outputs[key] = value[:,:self.num_queries-1]
                if self.task_switch['grounding']:
                    _outputs['pred_gtexts'] = value[:,self.num_queries:2*self.num_queries-1]
            elif key == 'aux_outputs':
                _outputs[key] = []
                for i in range(len(value)):
                    _outputs[key] += [{}]
                    for _key, _value in value[i].items():
                        if _key == 'pred_logits':
                            _outputs[key][i][_key] = _value[:,:self.num_queries-1]
                        elif _key == 'pred_masks':
                            _outputs[key][i][_key] = _value[:,:self.num_queries-1]
                            if self.task_switch['grounding']:
                                _outputs[key][i]['pred_gmasks'] = _value[:,self.num_queries:2*self.num_queries-1]
                        elif _key == 'pred_captions':
                            _outputs[key][i][_key] = _value[:,:self.num_queries-1]
                            if self.task_switch['grounding']:
                                _outputs[key][i]['pred_gtexts'] = _value[:,self.num_queries:2*self.num_queries-1]
        outputs = _outputs

        extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
                 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default'))}

        # bipartite matching-based loss
        self.criterion.losses = self.losses['seg']  # seg criterion losses
        losses = self.criterion(outputs, targets, extra)

        del outputs
        del _outputs
        return losses

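    # NOTE: forward_vlp handles the image-text (retrieval / captioning) sub-batch; when
    # retrieval_emsemble is enabled it also pools the backbone 'res5' feature map,
    # projects it through self.backbone_proj, and adds a backbone-level contrastive loss
    # ('loss_retrieval_backbone_0') on top of the decoder-side retrieval loss.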
    def forward_vlp(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        targets_vlp = self.prepare_vlp_targets(batched_inputs, images.tensor.device)

        extra = {"token_embedding": self.sem_seg_head.predictor.lang_encoder.lang_encoder.token_embedding,
                 "lang_encoder": self.sem_seg_head.predictor.lang_encoder,
                 "training": self.training}

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=None, target_vlp=targets_vlp, task='vlp', extra=extra)

        for key, value in outputs.items():
            if key == 'pred_captionings':
                outputs[key] = value
            elif key == 'pred_captions':
                # outputs[key] = value[:,-1:]
                outputs[key] = value
            elif key == 'aux_outputs':
                outputs[key] = []
                for i in range(len(value)):
                    outputs[key] += [{}]
                    for _key, _value in value[i].items():
                        if _key == 'pred_captions':
                            # outputs[key][i][_key] = _value[:,-1:]
                            outputs[key][i][_key] = _value
                        elif _key == 'pred_captionings':
                            outputs[key][i][_key] = _value

        self.criterion.losses = self.losses['vlp']  # vlp criterion losses
        losses = self.criterion.forward_vlp(outputs, targets_vlp, extra)
        del outputs

        if self.task_switch['retrieval'] and self.retrieval_emsemble:
            # compute backbone vlp loss.
            v_emb = features['res5']
            bs, nc, _, _ = v_emb.shape
            v_emb = v_emb.reshape(bs, nc, -1)
            v_emb = F.adaptive_avg_pool1d(v_emb, 1).reshape(bs, nc) @ self.backbone_proj
            t_emb = torch.cat([x['caption_proj'] for x in targets_vlp], dim=0)
            loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, self.sem_seg_head.predictor.lang_encoder, None)
            losses['loss_retrieval_backbone_0'] = loss_contrast
        return losses

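    # NOTE: evaluate() first upsamples the predicted masks to the padded input resolution
    # and then, depending on sem_seg_postprocess_before_inference (see __init__), crops and
    # resizes them back to each image's original "height"/"width" either before or after
    # running the semantic / panoptic / instance inference routines.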
    def evaluate(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]

        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
        caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bicubic",
            align_corners=False,
            antialias=True
        )

        input_size = mask_pred_results.shape[-2:]
        keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
        del outputs

        processed_results = []
        for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width
                )
                mask_cls_result = mask_cls_result.to(mask_pred_result)

            # semantic segmentation inference
            if self.semantic_on:
                r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
                if not self.sem_seg_postprocess_before_inference:
                    r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                processed_results[-1]["sem_seg"] = r

            # panoptic segmentation inference
            if self.panoptic_on:
                panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                processed_results[-1]["panoptic_seg"] = panoptic_r

            # instance segmentation inference
            if self.instance_on:
                if self.task_switch['bbox']:
                    box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
                instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
                processed_results[-1]["instances"] = instance_r
            if self.task_switch['caption']:
                processed_results[-1]["captions"] = caption_pred_result
                processed_results[-1]["masks"] = mask_pred_result

        return processed_results

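    # NOTE: evaluate_retrieval returns, per image, the decoder image embedding (the last
    # 'pred_captions' query) plus, if the backbone ensemble is on, a pooled res5 embedding,
    # together with the text embeddings of that image's candidate captions; the actual
    # similarity ranking is assumed to happen in the downstream retrieval evaluator.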
    def evaluate_retrieval(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)
        v_emb_it = outputs['pred_captions'][:,-1]

        # compute backbone score
        if self.task_switch['retrieval'] and self.retrieval_emsemble:
            _v_emb_it = features['res5']
            bs, nc, _, _ = _v_emb_it.shape
            _v_emb_it = _v_emb_it.reshape(bs, nc, -1)
            _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs, nc) @ self.backbone_proj

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            caption_ids = []
            t_emb_its = []
            processed_results.append({})
            for caption in batch_data['captions']:
                lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
                t_emb_it = lang_results['class_emb']
                caption_ids.append(batch_data['image_id'])
                t_emb_its.append(t_emb_it)

            t_emb_it = torch.cat(t_emb_its, dim=0)

            image_embeds = [v_emb_it[idx].unsqueeze(0)]
            if self.task_switch['retrieval'] and self.retrieval_emsemble:
                image_embeds += [_v_emb_it[idx].unsqueeze(0)]
            caption_results = {
                'image_embeds': image_embeds,
                'text_embeds': t_emb_it,
                'caption_ids': caption_ids,
                'image_ids': batch_data['image_id'],
            }
            processed_results[-1]["caption"] = caption_results

        del features
        return processed_results

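    # NOTE: the hard-coded start token 49406 repeated to length 77 corresponds to the CLIP
    # BPE tokenizer's start-of-text id and context length (assuming the language encoder
    # uses the CLIP tokenizer); the decoded caption is truncated at the first '.' below.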
    def evaluate_captioning(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        if not hasattr(self, 'start_token'):
            self.start_token = torch.tensor([[49406]*77], device=self.device)

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)

        captioning_mask = None
        if 'captioning_mask' in batched_inputs[-1]:
            captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])

        outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra={'start_token': self.start_token, 'captioning_mask': captioning_mask})

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            processed_results.append({})
            processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
            processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
            processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']

        return processed_results

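    # NOTE: image-level classification reuses the segmentation head: the logits of the last
    # query, outputs['pred_logits'][idx, -1], are taken as the class prediction for the
    # whole image.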
    def evaluate_classification(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            processed_results.append({})
            processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
        return processed_results

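    # NOTE: the baseline grounding path does not feed grounding tokens to the decoder; it
    # scores each referring text against the per-query caption embeddings (dot product on
    # L2-normalised vectors) and returns the mask of the best-matching query.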
    def evaluate_grounding_baseline(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_pred_results = outputs["pred_masks"]
        caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bicubic",
            align_corners=False,
            antialias=True
        )

        processed_results = []
        for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
            mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )[:-1]

            texts_all = input_per_image['groundings']['texts']
            grd_masks = []
            for texts in texts_all:
                if mode == 'grounding_refcoco':
                    self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
                elif mode == 'grounding_phrasecut':
                    self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
                t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
                v_emb = caption_pred_result[:-1]
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
                vt_sim = v_emb @ t_emb
                max_id = vt_sim.max(0)[1][0]
                grd_masks += [mask_pred_result[max_id]]
            processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)

        return processed_results

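    # NOTE: evaluate_grounding encodes all referring texts of an image at once, injects them
    # as 'grounding_tokens', and matches each text to a grounding-query mask via
    # vl_similarity on the normalised embeddings; the backbone is currently re-run inside
    # the per-image loop, so the same features are recomputed when the batch size exceeds 1.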
    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        extra = {}
        # Single-text (per-phrase) inference path, kept for reference but disabled in
        # favor of the multi-object inference below.
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']

        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]

        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]

        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # multi-object inference: encode all referring texts at once.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            extra['grounding_tokens'] = query_emb[:,None]

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
            v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bicubic",
                align_corners=False,
                antialias=True
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

                def prepare_vlp_targets(self, batched_inputs, device):
         | 
| 718 | 
            +
                    input_ids = []
         | 
| 719 | 
            +
                    attention_mask = []
         | 
| 720 | 
            +
                    for cnt, x in enumerate(batched_inputs):
         | 
| 721 | 
            +
                        captions = x['captions']
         | 
| 722 | 
            +
                        randid = random.randint(0, len(captions)-1)
         | 
| 723 | 
            +
                        input_ids += x['tokens']['input_ids'][randid:randid+1]
         | 
| 724 | 
            +
                        attention_mask += x['tokens']['attention_mask'][randid:randid+1]
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                    input_ids = torch.stack(input_ids)
         | 
| 727 | 
            +
                    attention_mask = torch.stack(attention_mask)
         | 
| 728 | 
            +
                    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
         | 
| 729 | 
            +
                    lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    target_vlp = []
         | 
| 732 | 
            +
                    for cnt, x in enumerate(batched_inputs):
         | 
| 733 | 
            +
                        target_dict = {}
         | 
| 734 | 
            +
                        target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
         | 
| 735 | 
            +
                        target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
         | 
| 736 | 
            +
                        target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
         | 
| 737 | 
            +
                        target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]            
         | 
| 738 | 
            +
                        target_vlp.append(target_dict)
         | 
| 739 | 
            +
                    return target_vlp
         | 
| 740 | 
            +
                
         | 
| 741 | 
            +
                def prepare_targets(self, batched_inputs, images):
         | 
| 742 | 
            +
                    h_pad, w_pad = images.tensor.shape[-2:]
         | 
| 743 | 
            +
                    new_targets = []
         | 
| 744 | 
            +
                    for idx, batch_per_image in enumerate(batched_inputs):
         | 
| 745 | 
            +
                        targets_per_image = batch_per_image["instances"].to(self.device)
         | 
| 746 | 
            +
             | 
| 747 | 
            +
                        # pad gt
         | 
| 748 | 
            +
                        gt_masks = targets_per_image.gt_masks
         | 
| 749 | 
            +
                        padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
         | 
| 750 | 
            +
                        padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                        gt_boxes = targets_per_image.gt_boxes.tensor
         | 
| 753 | 
            +
                        ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
         | 
| 754 | 
            +
                        gt_boxes = gt_boxes / ratio
         | 
| 755 | 
            +
                        xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
         | 
| 756 | 
            +
                        gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
         | 
| 757 | 
            +
             | 
| 758 | 
            +
                        target_dict = {
         | 
| 759 | 
            +
                                "labels": targets_per_image.gt_classes,
         | 
| 760 | 
            +
                                "is_things": targets_per_image.is_things,
         | 
| 761 | 
            +
                                "masks": padded_masks,
         | 
| 762 | 
            +
                                "boxes": gt_boxes
         | 
| 763 | 
            +
                                }
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                        if self.task_switch['caption']:
         | 
| 766 | 
            +
                            caption = batch_per_image["captions"]
         | 
| 767 | 
            +
                            caption_noun = batch_per_image["captions_noun"]
         | 
| 768 | 
            +
                            rand_index = random.randint(0, len(caption)-1)
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                            text = caption[rand_index]
         | 
| 771 | 
            +
                            nouns = caption_noun[rand_index]
         | 
| 772 | 
            +
                            noun_captions = [prompt_engineering(noun, topk=10000, suffix='.') for noun in nouns] + [text]
         | 
| 773 | 
            +
                            
         | 
| 774 | 
            +
                            self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(noun_captions, is_eval=False, name='caption_noun', prompt=False)
         | 
| 775 | 
            +
                            ctext = getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption_noun'))
         | 
| 776 | 
            +
                            target_dict["captions"] = ctext
         | 
| 777 | 
            +
                            
         | 
| 778 | 
            +
                            target_dict["captions_hash"] = [(hash(st.stem(txt)) % 10**16) for txt in (nouns + [text])]
         | 
| 779 | 
            +
                            target_dict["labels_hash"] = [(hash(st.stem(COCO_PANOPTIC_CLASSES[label_id].replace('-other','').replace('-merged','').replace('-stuff',''))) % 10**16) for label_id in target_dict['labels']]
         | 
| 780 | 
            +
                            
         | 
| 781 | 
            +
                        if self.task_switch['grounding']:
         | 
| 782 | 
            +
                            grd_masks = batch_per_image['groundings']['masks']
         | 
| 783 | 
            +
                            grd_texts = batch_per_image['groundings']['texts']
         | 
| 784 | 
            +
                            grd_hash = batch_per_image['groundings']['hash']
         | 
| 785 | 
            +
                            grd_task = batch_per_image['groundings']['mode']
         | 
| 786 | 
            +
                            
         | 
| 787 | 
            +
                            if len(grd_masks) == 0:
         | 
| 788 | 
            +
                                padded_masks = None
         | 
| 789 | 
            +
                            else:
         | 
| 790 | 
            +
                                padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
         | 
| 791 | 
            +
                                padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
         | 
| 794 | 
            +
                            token_emb = gtext['token_emb']
         | 
| 795 | 
            +
                            tokens = gtext['tokens']
         | 
| 796 | 
            +
                            
         | 
| 797 | 
            +
                            unique_hash_id = np.unique(grd_hash, return_index=True)[1]
         | 
| 798 | 
            +
                            selected_mask = np.zeros(len(grd_hash)).astype(np.bool)
         | 
| 799 | 
            +
                            selected_mask[unique_hash_id] = True
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                            selected_token_emb = token_emb[selected_mask]
         | 
| 802 | 
            +
                            selected_attn_mask = tokens['attention_mask'][selected_mask]
         | 
| 803 | 
            +
                            query_emb = selected_token_emb[selected_attn_mask.bool()]
         | 
| 804 | 
            +
                            
         | 
| 805 | 
            +
                            class_idx = tokens['attention_mask'].sum(dim=-1) - 1
         | 
| 806 | 
            +
                            class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
         | 
| 807 | 
            +
                            class_emb = token_emb[class_idx]
         | 
| 808 | 
            +
                            
         | 
| 809 | 
            +
                            target_dict['grounding_masks'] = padded_masks
         | 
| 810 | 
            +
                            target_dict['grounding_query_embs'] = query_emb
         | 
| 811 | 
            +
                            target_dict['grounding_class_embs'] = class_emb
         | 
| 812 | 
            +
                            target_dict['grounding_hash'] = grd_hash
         | 
| 813 | 
            +
                            target_dict['grounding_task'] = grd_task
         | 
| 814 | 
            +
             | 
| 815 | 
            +
                        new_targets.append(target_dict)
         | 
| 816 | 
            +
                    return new_targets
         | 
| 817 | 
            +
             | 
| 818 | 
            +
                def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
         | 
| 819 | 
            +
                    if keep_sem_bgd:
         | 
| 820 | 
            +
                        mask_cls = F.softmax(mask_cls, dim=-1)
         | 
| 821 | 
            +
                    else:
         | 
| 822 | 
            +
                        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
         | 
| 823 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 824 | 
            +
                    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
         | 
| 825 | 
            +
                    return semseg
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                def panoptic_inference(self, mask_cls, mask_pred):
         | 
| 828 | 
            +
                    scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
         | 
| 829 | 
            +
                    mask_pred = mask_pred.sigmoid()
         | 
| 830 | 
            +
             | 
| 831 | 
            +
                    keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
         | 
| 832 | 
            +
                    cur_scores = scores[keep]
         | 
| 833 | 
            +
                    cur_classes = labels[keep]
         | 
| 834 | 
            +
                    cur_masks = mask_pred[keep]
         | 
| 835 | 
            +
                    cur_mask_cls = mask_cls[keep]
         | 
| 836 | 
            +
                    cur_mask_cls = cur_mask_cls[:, :-1]
         | 
| 837 | 
            +
                    cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
         | 
| 838 | 
            +
             | 
| 839 | 
            +
                    h, w = cur_masks.shape[-2:]
         | 
| 840 | 
            +
                    panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
         | 
| 841 | 
            +
                    segments_info = []
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                    current_segment_id = 0
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    if cur_masks.shape[0] == 0:
         | 
| 846 | 
            +
                        # We didn't detect any mask :(
         | 
| 847 | 
            +
                        return panoptic_seg, segments_info
         | 
| 848 | 
            +
                    else:
         | 
| 849 | 
            +
                        # take argmax
         | 
| 850 | 
            +
                        cur_mask_ids = cur_prob_masks.argmax(0)
         | 
| 851 | 
            +
                        stuff_memory_list = {}
         | 
| 852 | 
            +
                        thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
         | 
| 853 | 
            +
                        for k in range(cur_classes.shape[0]):
         | 
| 854 | 
            +
                            pred_class = cur_classes[k].item()
         | 
| 855 | 
            +
                            isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
         | 
| 856 | 
            +
                            mask_area = (cur_mask_ids == k).sum().item()
         | 
| 857 | 
            +
                            original_area = (cur_masks[k] >= 0.5).sum().item()
         | 
| 858 | 
            +
                            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
         | 
| 859 | 
            +
             | 
| 860 | 
            +
                            if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
         | 
| 861 | 
            +
                                if mask_area / original_area < self.overlap_threshold:
         | 
| 862 | 
            +
                                    continue
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                                # merge stuff regions
         | 
| 865 | 
            +
                                if not isthing:
         | 
| 866 | 
            +
                                    if int(pred_class) in stuff_memory_list.keys():
         | 
| 867 | 
            +
                                        panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
         | 
| 868 | 
            +
                                        continue
         | 
| 869 | 
            +
                                    else:
         | 
| 870 | 
            +
                                        stuff_memory_list[int(pred_class)] = current_segment_id + 1
         | 
| 871 | 
            +
             | 
| 872 | 
            +
                                current_segment_id += 1
         | 
| 873 | 
            +
                                panoptic_seg[mask] = current_segment_id
         | 
| 874 | 
            +
             | 
| 875 | 
            +
                                segments_info.append(
         | 
| 876 | 
            +
                                    {
         | 
| 877 | 
            +
                                        "id": current_segment_id,
         | 
| 878 | 
            +
                                        "isthing": bool(isthing),
         | 
| 879 | 
            +
                                        "category_id": int(pred_class),
         | 
| 880 | 
            +
                                    }
         | 
| 881 | 
            +
                                )
         | 
| 882 | 
            +
                        return panoptic_seg, segments_info
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                def instance_inference(self, mask_cls, mask_pred, box_pred):
         | 
| 885 | 
            +
                    # mask_pred is already processed to have the same shape as original input
         | 
| 886 | 
            +
                    image_size = mask_pred.shape[-2:]
         | 
| 887 | 
            +
             | 
| 888 | 
            +
                    # [Q, K]
         | 
| 889 | 
            +
                    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
         | 
| 890 | 
            +
                    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
         | 
| 891 | 
            +
                    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
         | 
| 892 | 
            +
                    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
         | 
| 893 | 
            +
             | 
| 894 | 
            +
                    labels_per_image = labels[topk_indices]
         | 
| 895 | 
            +
                    topk_indices = (topk_indices // self.sem_seg_head.num_classes)
         | 
| 896 | 
            +
                    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
         | 
| 897 | 
            +
                    mask_pred = mask_pred[topk_indices]
         | 
| 898 | 
            +
                    if box_pred is not None:
         | 
| 899 | 
            +
                        box_pred = box_pred[topk_indices]
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                    # if this is panoptic segmentation, we only keep the "thing" classes
         | 
| 902 | 
            +
                    if self.panoptic_on:
         | 
| 903 | 
            +
                        thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
         | 
| 904 | 
            +
                        keep = torch.zeros_like(scores_per_image).bool()
         | 
| 905 | 
            +
                        for i, lab in enumerate(labels_per_image):
         | 
| 906 | 
            +
                            keep[i] = lab in thing_dataset_id_to_contiguous_id.values()
         | 
| 907 | 
            +
             | 
| 908 | 
            +
                        scores_per_image = scores_per_image[keep]
         | 
| 909 | 
            +
                        labels_per_image = labels_per_image[keep]
         | 
| 910 | 
            +
                        mask_pred = mask_pred[keep]
         | 
| 911 | 
            +
             | 
| 912 | 
            +
                        if box_pred is not None:
         | 
| 913 | 
            +
                            box_pred = box_pred[keep]
         | 
| 914 | 
            +
             | 
| 915 | 
            +
                    result = Instances(image_size)
         | 
| 916 | 
            +
                    # mask (before sigmoid)
         | 
| 917 | 
            +
                    result.pred_masks = (mask_pred > 0).float()
         | 
| 918 | 
            +
                    # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 919 | 
            +
                    # Uncomment the following to get boxes from masks (this is slow)
         | 
| 920 | 
            +
             | 
| 921 | 
            +
                    if box_pred is not None:
         | 
| 922 | 
            +
                        result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
         | 
| 923 | 
            +
                    else:
         | 
| 924 | 
            +
                        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
         | 
| 925 | 
            +
             | 
| 926 | 
            +
                    # calculate average mask prob
         | 
| 927 | 
            +
                    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
         | 
| 928 | 
            +
                    result.scores = scores_per_image * mask_scores_per_image
         | 
| 929 | 
            +
                    result.pred_classes = labels_per_image
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                    return result
         | 
| 932 | 
            +
             | 
| 933 | 
            +
             | 
| 934 | 
            +
             | 
| 935 | 
            +
            @register_model
         | 
| 936 | 
            +
            def get_xdecoder_model(cfg, **kwargs):
         | 
| 937 | 
            +
                return GeneralizedXdecoder(cfg)
         | 
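The grounding-evaluation loop above ranks the decoder's mask proposals against each phrase embedding by cosine similarity scaled with the learned logit scale, then keeps the best-matching mask per phrase. Below is a minimal, self-contained sketch of that matching step only; the tensors and the fixed temperature value are placeholders for illustration, and the repo's vl_similarity helper may differ in detail.

# Illustrative sketch (not the repo's vl_similarity): match mask proposals to text by cosine similarity.
import torch

def match_masks_to_text(v_emb, t_emb, pred_masks, temperature=100.0):
    # Normalize per-query visual embeddings and per-phrase text embeddings, as in the code above.
    v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)   # [Q, C]
    t = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)   # [T, C]
    logits = temperature * v @ t.t()                         # [Q, T] scaled similarity
    matched = logits.max(0)[1]                               # best query index per phrase
    return pred_masks[matched]                               # [T, H, W]

# Toy example with random tensors.
masks = match_masks_to_text(torch.randn(100, 512), torch.randn(3, 512), torch.randn(100, 64, 64))
print(masks.shape)  # torch.Size([3, 64, 64])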
    	
        modeling/body/__init__.py
    ADDED
    
@@ -0,0 +1,10 @@
from .xdecoder_head import *
from .build import *

def build_xdecoder_head(config, *args, **kwargs):
    model_name = config['MODEL']['HEAD']
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    body = model_entrypoints(model_name)(config, *args, **kwargs)
    return body
    	
        modeling/body/build.py
    ADDED
    
@@ -0,0 +1,13 @@
_model_entrypoints = {}

def register_body(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
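Both register_body here and register_decoder in modeling/interface/build.py key the registry by the name of the file that defines the entrypoint, so a config string such as MODEL.HEAD resolves to the builder defined in the matching module. A stand-alone sketch of the same pattern follows; the get_toy_head entrypoint is hypothetical and exists only for this illustration.

# Stand-alone sketch of the module-name-keyed registry pattern used above.
_entrypoints = {}

def register(fn):
    # Register under the last component of the defining module's name.
    _entrypoints[fn.__module__.split('.')[-1]] = fn
    return fn

def build(name, *args, **kwargs):
    if name not in _entrypoints:
        raise ValueError(f'Unknown model: {name}')
    return _entrypoints[name](*args, **kwargs)

@register
def get_toy_head(cfg):   # hypothetical entrypoint for illustration only
    return {'built_with': cfg}

# In a single script the registry key is this module's own name ("__main__");
# in the repo each entrypoint lives in its own file, so the key is that file's name.
print(build('__main__', {'NUM_CLASSES': 1}))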
    	
        modeling/body/xdecoder_head.py
    ADDED
    
@@ -0,0 +1,126 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Dict

from torch import nn

from detectron2.layers import ShapeSpec

from .build import register_body
from ..vision.encoder import build_encoder
from ..interface import build_decoder
from ..utils import configurable


class XdecoderHead(nn.Module):

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        binary_classes: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

        if binary_classes:
            self.num_classes = 1

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):

        in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        # figure out in_channels to transformer predictor
        if in_features_type == "transformer_encoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        elif in_features_type == "pixel_embedding":
            transformer_predictor_in_channels = enc_cfg['MASK_DIM']
        elif in_features_type == "multi_scale_pixel_decoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        else:
            transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels

        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                lang_encoder,
                mask_classification=True,
                extra=extra,
            ),
            "binary_classes": enc_cfg['BINARY_CLASSES']
        }

    def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        return self.layers(features, mask, target_queries, target_vlp, task, extra)

    def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)

        if self.transformer_in_feature == "multi_scale_pixel_decoder":
            predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
        else:
            if self.transformer_in_feature == "transformer_encoder":
                assert (
                    transformer_encoder_features is not None
                ), "Please use the TransformerEncoderPixelDecoder."
                predictions = self.predictor(transformer_encoder_features, mask_features, mask)
            elif self.transformer_in_feature == "pixel_embedding":
                predictions = self.predictor(mask_features, mask_features, mask)
            else:
                predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
        return predictions


@register_body
def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
    return XdecoderHead(cfg, input_shape, lang_encoder, extra)
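from_config pulls the transformer predictor's input width from different config fields depending on TRANSFORMER_IN_FEATURE. The following stand-alone sketch distills just that branch with a made-up config dict; the keys mirror the ones read above, while the concrete numbers are arbitrary placeholders.

# Distilled sketch of the in_channels selection in XdecoderHead.from_config (illustrative values only).
def predictor_in_channels(cfg, backbone_channels):
    enc_cfg, dec_cfg = cfg['MODEL']['ENCODER'], cfg['MODEL']['DECODER']
    in_feat = dec_cfg['TRANSFORMER_IN_FEATURE']
    if in_feat in ("transformer_encoder", "multi_scale_pixel_decoder"):
        return enc_cfg['CONVS_DIM']        # features come from the pixel decoder
    if in_feat == "pixel_embedding":
        return enc_cfg['MASK_DIM']         # per-pixel mask embedding width
    return backbone_channels[in_feat]      # raw backbone feature channels otherwise

cfg = {'MODEL': {'ENCODER': {'CONVS_DIM': 512, 'MASK_DIM': 512},
                 'DECODER': {'TRANSFORMER_IN_FEATURE': 'multi_scale_pixel_decoder'}}}
print(predictor_in_channels(cfg, {'res5': 2048}))  # 512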
    	
        modeling/interface/__init__.py
    ADDED
    
@@ -0,0 +1,13 @@
from .xdecoder import *
from .seem_v0 import *
from .seem_v1 import *
from .seem_demo import *
from .build import *

def build_decoder(config, *args, **kwargs):
    model_name = config['MODEL']['DECODER']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, *args, **kwargs)
    	
        modeling/interface/build.py
    ADDED
    
@@ -0,0 +1,14 @@
_model_entrypoints = {}


def register_decoder(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
    	
        modeling/interface/modules.py
    ADDED
    
@@ -0,0 +1,200 @@
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from ..utils import MultiheadAttention


class SelfAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                             key=self.with_pos_embed(memory, pos),
                                             value=memory, attn_mask=memory_mask,
                                             key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt, avg_attn

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                             key=self.with_pos_embed(memory, pos),
                                             value=memory, attn_mask=memory_mask,
                                             key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)

        return tgt, avg_attn

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
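The layers in this file are the building blocks the SEEM/X-Decoder transformer decoders stack per layer, and MLP typically serves as a small prediction head on the query embeddings. A quick usage sketch follows; it assumes the MLP and FFNLayer classes above are in scope, and the 512/2048/4 dimensions are placeholders rather than the repo's configured values.

# Usage sketch for the helpers above (placeholder dimensions).
import torch

mlp = MLP(input_dim=512, hidden_dim=512, output_dim=4, num_layers=3)   # e.g. a box-regression head
ffn = FFNLayer(d_model=512, dim_feedforward=2048)

queries = torch.randn(101, 2, 512)        # [num_queries, batch, d_model]
print(mlp(queries).shape)                  # torch.Size([101, 2, 4])
print(ffn(queries).shape)                  # torch.Size([101, 2, 512])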
    	
        modeling/interface/prototype/__init__.py
    ADDED
    
(empty file)
    	
        modeling/interface/prototype/attention_data_struct_seemdemo.py
    ADDED
    
@@ -0,0 +1,265 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

predict_name_matcher = {"predictions_class": ["pred_logits"],
                        "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
                        "predictions_caption":["pred_captions", "pred_gtexts"],
                        "predictions_maskemb":["pred_maskembs", "pred_smaskembs"],
                        "predictions_pos_spatial":["pred_pspatials"],
                        "predictions_neg_spatial":["pred_nspatials"],
                        "predictions_pos_visual":["pred_pvisuals"],
                        "predictions_neg_visual":["pred_nvisuals"]}

predict_index_matcher = {"predictions_class": ["queries_object"],
                         "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
                                     "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"], 
         | 
| 23 | 
            +
                                     "predictions_caption": ["queries_object", "queries_grounding"], 
         | 
| 24 | 
            +
                                     "predictions_maskemb":["queries_object", "queries_spatial"], 
         | 
| 25 | 
            +
                                     "predictions_pos_spatial":["all"],
         | 
| 26 | 
            +
                                     "predictions_neg_spatial":["all"],
         | 
| 27 | 
            +
                                     "predictions_pos_visual":["all"],
         | 
| 28 | 
            +
                                     "predictions_neg_visual":["all"]}
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            class Variable(object):
         | 
| 31 | 
            +
                '''
         | 
| 32 | 
            +
                Store dataset variable for attention
         | 
| 33 | 
            +
                output: embedding that accumulates during cross/self attention
         | 
| 34 | 
            +
                pos: positional embedding that is fixed during cross/self attention
         | 
| 35 | 
            +
                name: name of the variable
         | 
| 36 | 
            +
                type: type of the variable, e.g. queries, tokens
         | 
| 37 | 
            +
                attn_mask: attention mask for cross attention
         | 
| 38 | 
            +
                masking: masking for padding
         | 
| 39 | 
            +
                '''
         | 
| 40 | 
            +
                def __init__(self, output, name, _type, pos=None):
         | 
| 41 | 
            +
                    self.output = output
         | 
| 42 | 
            +
                    self.pos = pos
         | 
| 43 | 
            +
                    self.name = name
         | 
| 44 | 
            +
                    self.type = _type
         | 
| 45 | 
            +
                    self.attn_mask = None
         | 
| 46 | 
            +
                    self.masking = None
         | 
| 47 | 
            +
                
         | 
| 48 | 
            +
                def copy(self,):
         | 
| 49 | 
            +
                    output = self.output.clone() if self.output is not None else None
         | 
| 50 | 
            +
                    pos = self.pos.clone() if self.pos is not None else None
         | 
| 51 | 
            +
                    return Variable(output, self.name, self.type, pos)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            class AttentionDataStruct(nn.Module):
         | 
| 54 | 
            +
                '''
         | 
| 55 | 
            +
                Store dataset structure for cross/self attention
         | 
| 56 | 
            +
                task_switch: switch for different tasks
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                p_attn_variables: prototype of variables that is used in cross/self attention
         | 
| 59 | 
            +
                p_self_attn: prototype of variables that is used in self attention
         | 
| 60 | 
            +
                p_cross_attn: prototype of variables that is used in cross attention
         | 
| 61 | 
            +
                p_iter: prototype of iteration for different queries
         | 
| 62 | 
            +
                p_masking: prototype of masking for different tokens
         | 
| 63 | 
            +
                p_duplication: prototype of duplication for different queries
         | 
| 64 | 
            +
                '''
         | 
| 65 | 
            +
                def __init__(self, attn_arch, task_switch):
         | 
| 66 | 
            +
                    super(AttentionDataStruct, self).__init__()
         | 
| 67 | 
            +
                    self.task_switch = task_switch
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # p stands for prototype
         | 
| 70 | 
            +
                    self.p_attn_variables = attn_arch['VARIABLE']
         | 
| 71 | 
            +
                    self.p_self_attn = attn_arch['SELF_ATTENTION']
         | 
| 72 | 
            +
                    self.p_cross_attn = attn_arch['CROSS_ATTENTION']
         | 
| 73 | 
            +
                    self.p_masking = attn_arch['MASKING']
         | 
| 74 | 
            +
                    self.p_duplication = attn_arch['DUPLICATION']
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.num_layers = attn_arch['NUM_LAYERS']
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def reset(self, flags, task, extra):
         | 
| 79 | 
            +
                    # reset variables
         | 
| 80 | 
            +
                    self.attn_variables = {}
         | 
| 81 | 
            +
                    self.cross_attn_dict = {}
         | 
| 82 | 
            +
                    self.self_attn_dict = {}
         | 
| 83 | 
            +
                    self.duplication_dict = {}
         | 
| 84 | 
            +
                    self.query_index = {}
         | 
| 85 | 
            +
                    self.output = {}
         | 
| 86 | 
            +
                    self.flags = {}
         | 
| 87 | 
            +
                    self.spatial_memory = {}
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # initialize duplication
         | 
| 90 | 
            +
                    for key, values in self.p_duplication.items():
         | 
| 91 | 
            +
                        for name in values:
         | 
| 92 | 
            +
                            self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    # initialize flag
         | 
| 95 | 
            +
                    self.flags = {"object": True}
         | 
| 96 | 
            +
                    self.flags.update(flags)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    # initialize task
         | 
| 99 | 
            +
                    self.task = task
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    # initialize output
         | 
| 102 | 
            +
                    if self.task_switch['mask']:
         | 
| 103 | 
            +
                        self.output['predictions_class'] = []
         | 
| 104 | 
            +
                        self.output['predictions_mask'] = []
         | 
| 105 | 
            +
                        self.output['predictions_maskemb'] = []
         | 
| 106 | 
            +
                    
         | 
| 107 | 
            +
                    if self.task_switch['bbox']:
         | 
| 108 | 
            +
                        self.output['predictions_bbox'] = []
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
         | 
| 111 | 
            +
                        self.output['predictions_pos_spatial'] = []
         | 
| 112 | 
            +
                        self.output['predictions_neg_spatial'] = []
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
         | 
| 115 | 
            +
                        self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    if (self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True)) \
         | 
| 118 | 
            +
                            or (self.task_switch['audio'] and ('audio' in self.flags and self.flags['audio']==True)):
         | 
| 119 | 
            +
                        self.output['predictions_caption'] = []
         | 
| 120 | 
            +
                    
         | 
| 121 | 
            +
                    if self.task_switch['visual']:
         | 
| 122 | 
            +
                        self.output['predictions_pos_visual'] = []
         | 
| 123 | 
            +
                        self.output['predictions_neg_visual'] = []
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # initialize cross_attn, whether the variable is used in cross attention
         | 
| 126 | 
            +
                    for key, values in self.p_cross_attn.items():
         | 
| 127 | 
            +
                        for name in values:
         | 
| 128 | 
            +
                            self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
         | 
| 129 | 
            +
                    
         | 
| 130 | 
            +
                    # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
         | 
| 131 | 
            +
                    for key, values in self.p_self_attn.items():
         | 
| 132 | 
            +
                        for name in values:
         | 
| 133 | 
            +
                            self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                    # initialize masking
         | 
| 136 | 
            +
                    self.masking = self.p_masking
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    # initialize query_index
         | 
| 139 | 
            +
                    self.query_index = {"all":[0, None]}
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
                def set(self, name, _type, output=None, pos=None, var=None):
         | 
| 143 | 
            +
                    if var is not None:
         | 
| 144 | 
            +
                        self.attn_variables[name] = var
         | 
| 145 | 
            +
                    elif name in self.duplication_dict:
         | 
| 146 | 
            +
                        assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
         | 
| 147 | 
            +
                        self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
         | 
| 148 | 
            +
                    else:
         | 
| 149 | 
            +
                        var = Variable(output, name, _type, pos)
         | 
| 150 | 
            +
                        self.attn_variables[name] = var
         | 
| 151 | 
            +
                
         | 
| 152 | 
            +
                def set_results(self, results):
         | 
| 153 | 
            +
                    for name in self.cross_attn_name:
         | 
| 154 | 
            +
                        self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
         | 
| 155 | 
            +
                    for key in self.output:
         | 
| 156 | 
            +
                        self.output[key].append(results[key])
         | 
| 157 | 
            +
                
         | 
| 158 | 
            +
                def set_maskings(self, name, masking):
         | 
| 159 | 
            +
                    self.attn_variables[name].masking = masking
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                def cross_attn_variables(self, ):
         | 
| 162 | 
            +
                    cross_attn_name = [key for key, value in self.cross_attn_dict.items() 
         | 
| 163 | 
            +
                                       if (value==True) and (key in self.attn_variables) 
         | 
| 164 | 
            +
                                       and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 165 | 
            +
                    self.cross_attn_name = cross_attn_name
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
         | 
| 168 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
         | 
| 169 | 
            +
                    
         | 
| 170 | 
            +
                    index = 0
         | 
| 171 | 
            +
                    for name in cross_attn_name:
         | 
| 172 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 173 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 174 | 
            +
                    return output, pos_emb
         | 
| 175 | 
            +
                
         | 
| 176 | 
            +
                def cross_attn_mask(self, size, num_heads):
         | 
| 177 | 
            +
                    attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # hard code memories_spatial to previous selected mask
         | 
| 180 | 
            +
                    if 'memories_spatial' in self.cross_attn_name:
         | 
| 181 | 
            +
                        memory_attn_mask = self.spatial_memory['prev_batch_mask']
         | 
| 182 | 
            +
                        bs,c,_,_ = memory_attn_mask.shape
         | 
| 183 | 
            +
                        memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
         | 
| 184 | 
            +
                        memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
         | 
| 185 | 
            +
                        attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
         | 
| 188 | 
            +
                    return attn_mask
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                def self_attn(self, bs, num_heads):
         | 
| 191 | 
            +
                    self_attn_name = [key for key, value in self.self_attn_dict.items() 
         | 
| 192 | 
            +
                                      if len(value)>0 and key in self.attn_variables
         | 
| 193 | 
            +
                                      and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 194 | 
            +
                    self.self_attn_name = self_attn_name
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
         | 
| 197 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    index = 0
         | 
| 200 | 
            +
                    for name in self_attn_name:
         | 
| 201 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 202 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 203 | 
            +
                    
         | 
| 204 | 
            +
                    self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
         | 
| 205 | 
            +
                    self_attn_pair = []
         | 
| 206 | 
            +
                    # build self_attention mask by query interaction
         | 
| 207 | 
            +
                    for key1, value in self.self_attn_dict.items():
         | 
| 208 | 
            +
                        for key2 in value:
         | 
| 209 | 
            +
                            if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 210 | 
            +
                                # exclude the variables that are not used in the current layer
         | 
| 211 | 
            +
                                continue
         | 
| 212 | 
            +
                            if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
         | 
| 213 | 
            +
                                self_attn_pair += [[key1, key2]]
         | 
| 214 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    # build self_attention mask by masking, for bidirectional
         | 
| 217 | 
            +
                    for key in self.masking:
         | 
| 218 | 
            +
                        if key in self_attn_name:
         | 
| 219 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
         | 
| 220 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    # build self_attention mask by masking, for uni-directional
         | 
| 223 | 
            +
                    for key1, key2 in self_attn_pair:
         | 
| 224 | 
            +
                        if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 225 | 
            +
                            # exclude the variables that are not used in the current layer
         | 
| 226 | 
            +
                            continue
         | 
| 227 | 
            +
                        if key1 in self.masking:
         | 
| 228 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
         | 
| 229 | 
            +
                        if key2 in self.masking:
         | 
| 230 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
         | 
| 233 | 
            +
                    return output, pos_emb, self_attn_mask
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                def update_variables(self, output, mode):
         | 
| 236 | 
            +
                    name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
         | 
| 237 | 
            +
                    for key in name_set:
         | 
| 238 | 
            +
                        self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def update_spatial_results(self, results):
         | 
| 241 | 
            +
                    v_emb = results['pred_smaskembs']
         | 
| 242 | 
            +
                    pred_smasks = results['pred_smasks']
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    s_emb = results['pred_pspatials']
         | 
| 245 | 
            +
                    pred_logits = v_emb @ s_emb.transpose(1,2)
         | 
| 246 | 
            +
                    logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
         | 
| 247 | 
            +
                    logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
         | 
| 248 | 
            +
                    logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
         | 
| 249 | 
            +
                    pred_masks_pos = pred_smasks[logits_idx][:,None,]
         | 
| 250 | 
            +
                    
         | 
| 251 | 
            +
                    extra = {"prev_mask": pred_masks_pos}
         | 
| 252 | 
            +
                    return extra
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def organize_output(self, ):
         | 
| 255 | 
            +
                    outputs = {}
         | 
| 256 | 
            +
                    outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    for key, values in self.output.items():
         | 
| 259 | 
            +
                        for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
         | 
| 260 | 
            +
                            if idx_name not in self.query_index:
         | 
| 261 | 
            +
                                continue
         | 
| 262 | 
            +
                            outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 263 | 
            +
                            for idx, aux_values in enumerate(self.output[key][:-1]):
         | 
| 264 | 
            +
                                outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 265 | 
            +
                    return outputs
         | 
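
Note: AttentionDataStruct above fuses every active query variable into a single tensor before attention and records a [start, end) range per variable in query_index (built in cross_attn_variables / self_attn, consumed by update_variables and organize_output). A minimal sketch of that bookkeeping follows; the variable names mirror the code, while the counts and shapes are made-up assumptions.

import torch

# Assumed per-variable query counts; the real numbers come from the decoder config.
variables = {
    "queries_object": 101,
    "queries_grounding": 5,
    "queries_spatial": 3,
}

# Mirrors cross_attn_variables(): concatenate along the query dimension and
# remember a [start, end) range for every variable.
query_index, index, chunks = {}, 0, []
for name, n in variables.items():
    query_index[name] = [index, index + n]
    index += n
    chunks.append(torch.randn(n, 2, 256))   # (num_queries, batch, dim), illustrative
output = torch.cat(chunks)                   # fused queries, shape (109, 2, 256)

# Mirrors update_variables()/organize_output(): recover one variable's slice from the fused tensor.
start, end = query_index["queries_object"]
object_queries = output[start:end]
print(object_queries.shape)                  # torch.Size([101, 2, 256])
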
    	
        modeling/interface/prototype/attention_data_struct_seemv0.py
    ADDED
    
    | @@ -0,0 +1,264 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            predict_name_matcher = {"predictions_class": ["pred_logits"], 
         | 
| 6 | 
            +
                                    "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"], 
         | 
| 7 | 
            +
                                    "predictions_caption":["pred_captions", "pred_gtexts"], 
         | 
| 8 | 
            +
                                    "predictions_maskemb":["pred_smaskembs"], 
         | 
| 9 | 
            +
                                    "predictions_pos_spatial":["pred_pspatials"],
         | 
| 10 | 
            +
                                    "predictions_neg_spatial":["pred_nspatials"],}
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            predict_index_matcher = {"predictions_class": ["queries_object"], 
         | 
| 13 | 
            +
                                     "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"], 
         | 
| 14 | 
            +
                                     "predictions_caption": ["queries_object", "queries_grounding"], 
         | 
| 15 | 
            +
                                     "predictions_maskemb":["queries_spatial"], 
         | 
| 16 | 
            +
                                     "predictions_pos_spatial":["all"],
         | 
| 17 | 
            +
                                     "predictions_neg_spatial":["all"],}
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            class Variable(object):
         | 
| 20 | 
            +
                '''
         | 
| 21 | 
            +
                Store dataset variable for attention
         | 
| 22 | 
            +
                output: embedding that accumulates during cross/self attention
         | 
| 23 | 
            +
                pos: positional embedding that is fixed during cross/self attention
         | 
| 24 | 
            +
                name: name of the variable
         | 
| 25 | 
            +
                type: type of the variable, e.g. queries, tokens
         | 
| 26 | 
            +
                attn_mask: attention mask for cross attention
         | 
| 27 | 
            +
                masking: masking for padding
         | 
| 28 | 
            +
                '''
         | 
| 29 | 
            +
                def __init__(self, output, name, _type, pos=None):
         | 
| 30 | 
            +
                    self.output = output
         | 
| 31 | 
            +
                    self.pos = pos
         | 
| 32 | 
            +
                    self.name = name
         | 
| 33 | 
            +
                    self.type = _type
         | 
| 34 | 
            +
                    self.attn_mask = None
         | 
| 35 | 
            +
                    self.masking = None
         | 
| 36 | 
            +
                
         | 
| 37 | 
            +
                def copy(self,):
         | 
| 38 | 
            +
                    output = self.output.clone() if self.output is not None else None
         | 
| 39 | 
            +
                    pos = self.pos.clone() if self.pos is not None else None
         | 
| 40 | 
            +
                    return Variable(output, self.name, self.type, pos)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            class AttentionDataStruct(nn.Module):
         | 
| 43 | 
            +
                '''
         | 
| 44 | 
            +
                Store dataset structure for cross/self attention
         | 
| 45 | 
            +
                task_switch: switch for different tasks
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                p_attn_variables: prototype of variables that is used in cross/self attention
         | 
| 48 | 
            +
                p_self_attn: prototype of variables that is used in self attention
         | 
| 49 | 
            +
                p_cross_attn: prototype of variables that is used in cross attention
         | 
| 50 | 
            +
                p_iter: prototype of iteration for different queries
         | 
| 51 | 
            +
                p_masking: prototype of masking for different tokens
         | 
| 52 | 
            +
                p_duplication: prototype of duplication for different queries
         | 
| 53 | 
            +
                '''
         | 
| 54 | 
            +
                def __init__(self, attn_arch, task_switch):
         | 
| 55 | 
            +
                    super(AttentionDataStruct, self).__init__()
         | 
| 56 | 
            +
                    self.task_switch = task_switch
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # p stands for prototype
         | 
| 59 | 
            +
                    self.p_attn_variables = attn_arch['VARIABLE']
         | 
| 60 | 
            +
                    self.p_self_attn = attn_arch['SELF_ATTENTION']
         | 
| 61 | 
            +
                    self.p_cross_attn = attn_arch['CROSS_ATTENTION']
         | 
| 62 | 
            +
                    self.p_masking = attn_arch['MASKING']
         | 
| 63 | 
            +
                    self.p_duplication = attn_arch['DUPLICATION']
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self.num_layers = attn_arch['NUM_LAYERS']
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def reset(self, flags, task, extra):
         | 
| 68 | 
            +
                    # reset variables
         | 
| 69 | 
            +
                    self.attn_variables = {}
         | 
| 70 | 
            +
                    self.cross_attn_dict = {}
         | 
| 71 | 
            +
                    self.self_attn_dict = {}
         | 
| 72 | 
            +
                    self.duplication_dict = {}
         | 
| 73 | 
            +
                    self.query_index = {}
         | 
| 74 | 
            +
                    self.output = {}
         | 
| 75 | 
            +
                    self.flags = {}
         | 
| 76 | 
            +
                    self.spatial_memory = {}
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    # initialize duplication
         | 
| 79 | 
            +
                    for key, values in self.p_duplication.items():
         | 
| 80 | 
            +
                        for name in values:
         | 
| 81 | 
            +
                            self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    # initialize flag
         | 
| 84 | 
            +
                    self.flags = {"object": True}
         | 
| 85 | 
            +
                    self.flags.update(flags)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    # initialize task
         | 
| 88 | 
            +
                    self.task = task
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    # initialize output
         | 
| 91 | 
            +
                    if self.task_switch['mask']:
         | 
| 92 | 
            +
                        self.output['predictions_class'] = []
         | 
| 93 | 
            +
                        self.output['predictions_mask'] = []
         | 
| 94 | 
            +
                    
         | 
| 95 | 
            +
                    if self.task_switch['bbox']:
         | 
| 96 | 
            +
                        self.output['predictions_bbox'] = []
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
         | 
| 99 | 
            +
                        self.output['predictions_maskemb'] = []
         | 
| 100 | 
            +
                        self.output['predictions_pos_spatial'] = []
         | 
| 101 | 
            +
                        self.output['predictions_neg_spatial'] = []
         | 
| 102 | 
            +
                        # self.spatial_memory['spatial_query_mode'] = extra['spatial_query_mode']
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
         | 
| 105 | 
            +
                        self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
         | 
| 108 | 
            +
                        self.output['predictions_caption'] = []
         | 
| 109 | 
            +
                    
         | 
| 110 | 
            +
                    # initialize cross_attn, whether the variable is used in cross attention
         | 
| 111 | 
            +
                    for key, values in self.p_cross_attn.items():
         | 
| 112 | 
            +
                        for name in values:
         | 
| 113 | 
            +
                            self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
         | 
| 114 | 
            +
                    
         | 
| 115 | 
            +
                    # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
         | 
| 116 | 
            +
                    for key, values in self.p_self_attn.items():
         | 
| 117 | 
            +
                        for name in values:
         | 
| 118 | 
            +
                            self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
         | 
| 119 | 
            +
                    
         | 
| 120 | 
            +
                    # initialize masking
         | 
| 121 | 
            +
                    self.masking = self.p_masking
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    # initialize query_index
         | 
| 124 | 
            +
                    self.query_index = {"all":[0, None]}
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
                def set(self, name, _type, output=None, pos=None, var=None):
         | 
| 128 | 
            +
                    if var is not None:
         | 
| 129 | 
            +
                        self.attn_variables[name] = var
         | 
| 130 | 
            +
                    elif name in self.duplication_dict:
         | 
| 131 | 
            +
                        assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
         | 
| 132 | 
            +
                        self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
         | 
| 133 | 
            +
                    else:
         | 
| 134 | 
            +
                        var = Variable(output, name, _type, pos)
         | 
| 135 | 
            +
                        self.attn_variables[name] = var
         | 
| 136 | 
            +
                
         | 
| 137 | 
            +
                def set_results(self, results):
         | 
| 138 | 
            +
                    for name in self.cross_attn_name:
         | 
| 139 | 
            +
                        self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
         | 
| 140 | 
            +
                    for key in self.output:
         | 
| 141 | 
            +
                        self.output[key].append(results[key])
         | 
| 142 | 
            +
                
         | 
| 143 | 
            +
                def set_maskings(self, name, masking):
         | 
| 144 | 
            +
                    self.attn_variables[name].masking = masking
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def cross_attn_variables(self, ):
         | 
| 147 | 
            +
                    cross_attn_name = [key for key, value in self.cross_attn_dict.items() 
         | 
| 148 | 
            +
                                       if (value==True) and (key in self.attn_variables) 
         | 
| 149 | 
            +
                                       and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 150 | 
            +
                    self.cross_attn_name = cross_attn_name
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
         | 
| 153 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
         | 
| 154 | 
            +
                    
         | 
| 155 | 
            +
                    index = 0
         | 
| 156 | 
            +
                    for name in cross_attn_name:
         | 
| 157 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 158 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 159 | 
            +
                    return output, pos_emb
         | 
| 160 | 
            +
                
         | 
| 161 | 
            +
                def cross_attn_mask(self, size, num_heads):
         | 
| 162 | 
            +
                    attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    # hard code memories_spatial to previous selected mask
         | 
| 165 | 
            +
                    if 'memories_spatial' in self.cross_attn_name:
         | 
| 166 | 
            +
                        memory_attn_mask = self.spatial_memory['prev_batch_mask']
         | 
| 167 | 
            +
                        bs,c,_,_ = memory_attn_mask.shape
         | 
| 168 | 
            +
                        memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
         | 
| 169 | 
            +
                        memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
         | 
| 170 | 
            +
                        attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
         | 
| 173 | 
            +
                    return attn_mask
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def self_attn(self, bs, num_heads):
         | 
| 176 | 
            +
                    self_attn_name = [key for key, value in self.self_attn_dict.items() 
         | 
| 177 | 
            +
                                      if len(value)>0 and key in self.attn_variables
         | 
| 178 | 
            +
                                      and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 179 | 
            +
                    self.self_attn_name = self_attn_name
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
         | 
| 182 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    index = 0
         | 
| 185 | 
            +
                    for name in self_attn_name:
         | 
| 186 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 187 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                    self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
         | 
| 190 | 
            +
                    self_attn_pair = []
         | 
| 191 | 
            +
                    # build self_attention mask by query interaction
         | 
| 192 | 
            +
                    for key1, value in self.self_attn_dict.items():
         | 
| 193 | 
            +
                        for key2 in value:
         | 
| 194 | 
            +
                            if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 195 | 
            +
                                # exclude the variables that are not used in the current layer
         | 
| 196 | 
            +
                                continue
         | 
| 197 | 
            +
                            if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
         | 
| 198 | 
            +
                                self_attn_pair += [[key1, key2]]
         | 
| 199 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # build self_attention mask by masking, for bidirectional
         | 
| 202 | 
            +
                    for key in self.masking:
         | 
| 203 | 
            +
                        if key in self_attn_name:
         | 
| 204 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
         | 
| 205 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    # build self_attention mask by masking, for uni-directional
         | 
| 208 | 
            +
                    for key1, key2 in self_attn_pair:
         | 
| 209 | 
            +
                        if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 210 | 
            +
                            # exclude the variables that are not used in the current layer
         | 
| 211 | 
            +
                            continue
         | 
| 212 | 
            +
                        if key1 in self.masking:
         | 
| 213 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
         | 
| 214 | 
            +
                        if key2 in self.masking:
         | 
| 215 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
         | 
| 218 | 
            +
                    return output, pos_emb, self_attn_mask
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                def update_variables(self, output, mode):
         | 
| 221 | 
            +
                    name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
         | 
| 222 | 
            +
                    for key in name_set:
         | 
| 223 | 
            +
                        self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                def update_spatial_results(self, results):
         | 
| 226 | 
            +
                    v_emb = results['pred_smaskembs']
         | 
| 227 | 
            +
                    pred_smasks = results['pred_smasks']
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    s_emb = results['pred_pspatials']
         | 
| 230 | 
            +
                    pred_logits = v_emb @ s_emb.transpose(1,2)
         | 
| 231 | 
            +
                    logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
         | 
| 232 | 
            +
                    logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
         | 
| 233 | 
            +
                    logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
         | 
| 234 | 
            +
                    pred_masks_pos = pred_smasks[logits_idx][:,None,]
         | 
| 235 | 
            +
                    
         | 
| 236 | 
            +
                    # s_emb = results['pred_nspatials']
         | 
| 237 | 
            +
                    # pred_logits = v_emb @ s_emb.transpose(1,2)
         | 
| 238 | 
            +
                    # logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
         | 
| 239 | 
            +
                    # logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
         | 
| 240 | 
            +
                    # logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
         | 
| 241 | 
            +
                    # pred_masks_neg = pred_smasks[logits_idx][:,None,]
         | 
| 242 | 
            +
                    # # clip the negative mask to 0, and then multiply by -1
         | 
| 243 | 
            +
                    # pred_masks_neg = (pred_masks_neg.clip(0) * -1)
         | 
| 244 | 
            +
                    # keep_neg = (s_emb.sum(dim=list(range(1, s_emb.dim()))) != 0).float()
         | 
| 245 | 
            +
                    # pred_masks_neg = pred_masks_neg * keep_neg[:,None,None,None]
         | 
| 246 | 
            +
                    # extra = {"prev_mask": pred_masks_pos + pred_masks_neg}
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    extra = {"prev_mask": pred_masks_pos}
         | 
| 249 | 
            +
                    return extra
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                def organize_output(self, ):
         | 
| 252 | 
            +
                    outputs = {}
         | 
| 253 | 
            +
                    outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
         | 
| 254 | 
            +
                    for key, values in self.output.items():
         | 
| 255 | 
            +
                        for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
         | 
| 256 | 
            +
                            if idx_name not in self.query_index:
         | 
| 257 | 
            +
                                continue
         | 
| 258 | 
            +
                            outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 259 | 
            +
                            for idx, aux_values in enumerate(self.output[key][:-1]):
         | 
| 260 | 
            +
                                outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 261 | 
            +
                    if self.task == 'spatial' or self.task == 'refimg':
         | 
| 262 | 
            +
                        outputs = self.update_spatial_results(outputs)
         | 
| 263 | 
            +
                    # outputs = self.update_spatial_results(outputs)
         | 
| 264 | 
            +
                    return outputs
         | 
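
Note: update_spatial_results above selects, for each batch element, the spatial-query mask whose embedding is most similar to the positive spatial prompt embedding, and returns it as prev_mask so it can be fed back as spatial memory on the next pass. A standalone sketch of that selection step with random tensors; every size below is an illustrative assumption.

import torch

bs, num_queries, dim, h, w = 2, 3, 256, 32, 32
v_emb = torch.randn(bs, num_queries, dim)         # stands in for results['pred_smaskembs']
s_emb = torch.randn(bs, 1, dim)                   # stands in for results['pred_pspatials']
pred_smasks = torch.randn(bs, num_queries, h, w)  # stands in for results['pred_smasks']

pred_logits = v_emb @ s_emb.transpose(1, 2)                    # (bs, num_queries, 1) similarity
logits_idx_y = pred_logits[:, :, 0].max(dim=1)[1]              # best-matching query per sample
logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
prev_mask = pred_smasks[logits_idx][:, None]                   # (bs, 1, h, w) mask memory
print(prev_mask.shape)                                         # torch.Size([2, 1, 32, 32])
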
    	
        modeling/interface/prototype/attention_data_struct_seemv1.py
    ADDED
    
    | @@ -0,0 +1,302 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            predict_name_matcher = {"predictions_class": ["pred_logits"],
         | 
| 6 | 
            +
                                    "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
         | 
| 7 | 
            +
                                    "predictions_caption":["pred_captions", "pred_gtexts", "pred_stexts"],
         | 
| 8 | 
            +
                                    "predictions_maskemb":["pred_smaskembs"],
         | 
| 9 | 
            +
                                    "predictions_pos_spatial":["pred_pspatials"],
         | 
| 10 | 
            +
                                    "predictions_neg_spatial":["pred_nspatials"],}
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            predict_index_matcher = {"predictions_class": ["queries_object"],
         | 
| 13 | 
            +
                                     "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
         | 
| 14 | 
            +
                                     "predictions_caption": ["queries_object", "queries_grounding", "queries_spatial"],
         | 
| 15 | 
            +
                                     "predictions_maskemb":["queries_spatial"],
         | 
| 16 | 
            +
                                     "predictions_pos_spatial":["all"],
         | 
| 17 | 
            +
                                     "predictions_neg_spatial":["all"],}
         | 
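These two dicts pair up positionally: for a given output key, the n-th prediction name is filled from the query slice named by the n-th index entry ("all" meaning the full concatenated query tensor). A tiny illustrative snippet, not part of the file, showing the pairing for one key:

    key = "predictions_mask"
    for pred_name, query_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
        print(pred_name, "<-", query_name)
    # pred_masks  <- queries_object
    # pred_gmasks <- queries_grounding
    # pred_smasks <- queries_spatial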
| 18 | 
            +
             | 
| 19 | 
            +
            class Variable(object):
         | 
| 20 | 
            +
                '''
         | 
| 21 | 
            +
                Store dataset variable for attention
         | 
| 22 | 
            +
                output: embedding that accumulates during cross/self attention
         | 
| 23 | 
            +
                pos: positional embedding that is fixed during cross/self attention
         | 
| 24 | 
            +
                name: name of the variable
         | 
| 25 | 
            +
                type: type of the variable, e.g. queries, tokens
         | 
| 26 | 
            +
                attn_mask: attention mask for cross attention
         | 
| 27 | 
            +
                masking: masking for padding
         | 
| 28 | 
            +
                '''
         | 
| 29 | 
            +
                def __init__(self, output, name, _type, pos=None):
         | 
| 30 | 
            +
                    self.output = output
         | 
| 31 | 
            +
                    self.pos = pos
         | 
| 32 | 
            +
                    self.name = name
         | 
| 33 | 
            +
                    self.type = _type
         | 
| 34 | 
            +
                    self.attn_mask = None
         | 
| 35 | 
            +
                    self.masking = None
         | 
| 36 | 
            +
                
         | 
| 37 | 
            +
                def copy(self,):
         | 
| 38 | 
            +
                    output = self.output.clone() if self.output is not None else None
         | 
| 39 | 
            +
                    pos = self.pos.clone() if self.pos is not None else None
         | 
| 40 | 
            +
                    return Variable(output, self.name, self.type, pos)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def rand_sample(self, max_len):
         | 
| 43 | 
            +
                    rand_idx = torch.randint(0, len(self.pos), (max_len,))
         | 
| 44 | 
            +
                    self.output = self.output[rand_idx]
         | 
| 45 | 
            +
                    self.pos = self.pos[rand_idx]
         | 
| 46 | 
            +
                    return self
         | 
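A minimal usage sketch of Variable (illustrative only, not from the repo; shapes are made up). copy() clones output/pos while keeping name and type, and rand_sample() re-indexes both with replacement via torch.randint:

    import torch

    pos_points = torch.rand(100, 2)      # e.g. sampled point coordinates
    feats = torch.randn(100, 512)        # matching per-point features
    var = Variable(feats, 'tokens_spatial', 'tokens', pos=pos_points)

    dup = var.copy()                     # cloned output/pos, same name/type
    dup = dup.rand_sample(max_len=50)    # 50 rows sampled with replacement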
| 47 | 
            +
             | 
| 48 | 
            +
            class AttentionDataStruct(nn.Module):
         | 
| 49 | 
            +
                '''
         | 
| 50 | 
            +
                Store dataset structure for cross/self attention
         | 
| 51 | 
            +
                task_switch: switch for different tasks
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                p_attn_variables: prototype of variables that are used in cross/self attention
         | 
| 54 | 
            +
                p_self_attn: prototype of variables that are used in self attention
         | 
| 55 | 
            +
                p_cross_attn: prototype of variables that are used in cross attention
         | 
| 56 | 
            +
                p_iter: prototype of iteration for different queries
         | 
| 57 | 
            +
                p_masking: prototype of masking for different tokens
         | 
| 58 | 
            +
                p_duplication: prototype of duplication for different queries
         | 
| 59 | 
            +
                '''
         | 
| 60 | 
            +
                def __init__(self, attn_arch, task_switch):
         | 
| 61 | 
            +
                    super(AttentionDataStruct, self).__init__()
         | 
| 62 | 
            +
                    self.task_switch = task_switch
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # p stands for prototype
         | 
| 65 | 
            +
                    self.p_attn_variables = attn_arch['VARIABLE']
         | 
| 66 | 
            +
                    self.p_self_attn = attn_arch['SELF_ATTENTION']
         | 
| 67 | 
            +
                    self.p_cross_attn = attn_arch['CROSS_ATTENTION']
         | 
| 68 | 
            +
                    self.p_masking = attn_arch['MASKING']
         | 
| 69 | 
            +
                    self.p_duplication = attn_arch['DUPLICATION']
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.num_layers = attn_arch['NUM_LAYERS']
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def reset(self, flags, task, extra):
         | 
| 74 | 
            +
                    # reset variables
         | 
| 75 | 
            +
                    self.attn_variables = {}
         | 
| 76 | 
            +
                    self.cross_attn_dict = {}
         | 
| 77 | 
            +
                    self.self_attn_dict = {}
         | 
| 78 | 
            +
                    self.duplication_dict = {}
         | 
| 79 | 
            +
                    self.query_index = {}
         | 
| 80 | 
            +
                    self.output = {}
         | 
| 81 | 
            +
                    self.flags = {}
         | 
| 82 | 
            +
                    self.spatial_memory = {}
         | 
| 83 | 
            +
                    self.extra = {}
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    # initialize duplication
         | 
| 86 | 
            +
                    for key, values in self.p_duplication.items():
         | 
| 87 | 
            +
                        for name in values:
         | 
| 88 | 
            +
                            self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    # initialize flag
         | 
| 91 | 
            +
                    self.flags = {"object": True}
         | 
| 92 | 
            +
                    self.flags.update(flags)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    # initialize task
         | 
| 95 | 
            +
                    self.task = task
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    # initialize output
         | 
| 98 | 
            +
                    if self.task_switch['mask']:
         | 
| 99 | 
            +
                        self.output['predictions_class'] = []
         | 
| 100 | 
            +
                        self.output['predictions_mask'] = []
         | 
| 101 | 
            +
                    
         | 
| 102 | 
            +
                    if self.task_switch['bbox']:
         | 
| 103 | 
            +
                        self.output['predictions_bbox'] = []
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
         | 
| 106 | 
            +
                        self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
         | 
| 109 | 
            +
                        self.output['predictions_caption'] = []
         | 
| 110 | 
            +
                    
         | 
| 111 | 
            +
                    if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
         | 
| 112 | 
            +
                        self.output['predictions_maskemb'] = []
         | 
| 113 | 
            +
                        self.output['predictions_pos_spatial'] = []
         | 
| 114 | 
            +
                        self.output['predictions_neg_spatial'] = []
         | 
| 115 | 
            +
                        self.output['predictions_mask'] = [] if 'predictions_mask' not in self.output else self.output['predictions_mask']
         | 
| 116 | 
            +
                        self.output['predictions_class'] = [] if 'predictions_class' not in self.output else self.output['predictions_class']
         | 
| 117 | 
            +
                        self.output['predictions_caption'] = [] if 'predictions_caption' not in self.output else self.output['predictions_caption']
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    # initialize cross_attn, whether the variable is used in cross attention
         | 
| 120 | 
            +
                    for key, values in self.p_cross_attn.items():
         | 
| 121 | 
            +
                        for name in values:
         | 
| 122 | 
            +
                            self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
         | 
| 123 | 
            +
                    
         | 
| 124 | 
            +
                    # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
         | 
| 125 | 
            +
                    for key, values in self.p_self_attn.items():
         | 
| 126 | 
            +
                        for name in values:
         | 
| 127 | 
            +
                            self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
         | 
| 128 | 
            +
                    
         | 
| 129 | 
            +
                    # initialize masking
         | 
| 130 | 
            +
                    self.masking = self.p_masking
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    # initialize query_index
         | 
| 133 | 
            +
                    self.query_index = {"all":[0, None]}
         | 
| 134 | 
            +
             | 
| 135 | 
            +
             | 
| 136 | 
            +
                def set(self, name, _type, output=None, pos=None, var=None, sample_size=None):
         | 
| 137 | 
            +
                    if var is not None:
         | 
| 138 | 
            +
                        self.attn_variables[name] = var
         | 
| 139 | 
            +
                    elif name in self.duplication_dict:
         | 
| 140 | 
            +
                        assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
         | 
| 141 | 
            +
                        var = self.attn_variables[self.duplication_dict[name]].copy()
         | 
| 142 | 
            +
                        if sample_size is not None:
         | 
| 143 | 
            +
                            var = var.rand_sample(sample_size)
         | 
| 144 | 
            +
                        self.attn_variables[name] = var
         | 
| 145 | 
            +
                    else:
         | 
| 146 | 
            +
                        var = Variable(output, name, _type, pos)
         | 
| 147 | 
            +
                        self.attn_variables[name] = var
         | 
| 148 | 
            +
                
         | 
| 149 | 
            +
                def set_results(self, results):
         | 
| 150 | 
            +
                    for name in self.cross_attn_name:
         | 
| 151 | 
            +
                        self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
         | 
| 152 | 
            +
                    for key in self.output:
         | 
| 153 | 
            +
                        self.output[key].append(results[key])
         | 
| 154 | 
            +
                
         | 
| 155 | 
            +
                def set_maskings(self, name, masking):
         | 
| 156 | 
            +
                    self.attn_variables[name].masking = masking
         | 
| 157 | 
            +
                
         | 
| 158 | 
            +
                def set_extra(self, extra):
         | 
| 159 | 
            +
                    self.extra.update(extra)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                def cross_attn_variables(self, ):
         | 
| 162 | 
            +
                    cross_attn_name = [key for key, value in self.cross_attn_dict.items() 
         | 
| 163 | 
            +
                                       if (value==True) and (key in self.attn_variables) 
         | 
| 164 | 
            +
                                       and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 165 | 
            +
                    self.cross_attn_name = cross_attn_name
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
         | 
| 168 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
         | 
| 169 | 
            +
                    
         | 
| 170 | 
            +
                    index = 0
         | 
| 171 | 
            +
                    for name in cross_attn_name:
         | 
| 172 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 173 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 174 | 
            +
                    return output, pos_emb
         | 
| 175 | 
            +
                
         | 
| 176 | 
            +
                def cross_attn_mask(self, size, num_heads):
         | 
| 177 | 
            +
                    attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # hard-code memories_spatial to the previously selected mask
         | 
| 180 | 
            +
                    if 'memories_spatial' in self.cross_attn_name:
         | 
| 181 | 
            +
                        memory_attn_mask = self.spatial_memory['prev_batch_mask']
         | 
| 182 | 
            +
                        bs,c,_,_ = memory_attn_mask.shape
         | 
| 183 | 
            +
                        memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
         | 
| 184 | 
            +
                        memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
         | 
| 185 | 
            +
                        repeat = (self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]) // c
         | 
| 186 | 
            +
                        mem_len = self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]
         | 
| 187 | 
            +
                        probs = torch.tensor([1./repeat for i in range(c)])
         | 
| 188 | 
            +
                        indices = torch.multinomial(probs, num_samples=mem_len, replacement=True).sort()[0]
         | 
| 189 | 
            +
                        attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask[:,indices]
         | 
| 190 | 
            +
                        self.extra['memory_indices'] = indices
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
         | 
| 193 | 
            +
                    return attn_mask
         | 
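The assignment just before the return is a common guard: if a query's row masks out every key, attention over that row would degenerate (softmax over all -inf gives NaNs), so fully-masked rows are reset to attend everywhere. The same line on a made-up 2x2 mask:

    import torch

    attn_mask = torch.tensor([[True, True], [True, False]])
    attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
    # row 0 had masked everything -> reset to [False, False]; row 1 is unchanged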
| 194 | 
            +
             | 
| 195 | 
            +
                def self_attn(self, bs, num_heads):
         | 
| 196 | 
            +
                    self_attn_name = [key for key, value in self.self_attn_dict.items() 
         | 
| 197 | 
            +
                                      if len(value)>0 and key in self.attn_variables
         | 
| 198 | 
            +
                                      and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
         | 
| 199 | 
            +
                    self.self_attn_name = self_attn_name
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
         | 
| 202 | 
            +
                    pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    index = 0
         | 
| 205 | 
            +
                    for name in self_attn_name:
         | 
| 206 | 
            +
                        self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
         | 
| 207 | 
            +
                        index += self.attn_variables[name].output.shape[0]
         | 
| 208 | 
            +
                    
         | 
| 209 | 
            +
                    self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
         | 
| 210 | 
            +
                    self_attn_pair = []
         | 
| 211 | 
            +
                    # build self_attention mask by query interaction
         | 
| 212 | 
            +
                    for key1, value in self.self_attn_dict.items():
         | 
| 213 | 
            +
                        for key2 in value:
         | 
| 214 | 
            +
                            if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 215 | 
            +
                                # exclude the variables that are not used in the current layer
         | 
| 216 | 
            +
                                continue
         | 
| 217 | 
            +
                            if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
         | 
| 218 | 
            +
                                self_attn_pair += [[key1, key2]]
         | 
| 219 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    # build self_attention mask by masking, for bidirectional
         | 
| 222 | 
            +
                    for key in self.masking:
         | 
| 223 | 
            +
                        if key in self_attn_name:
         | 
| 224 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
         | 
| 225 | 
            +
                            self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # build self_attention mask by masking, for uni-directional
         | 
| 228 | 
            +
                    for key1, key2 in self_attn_pair:
         | 
| 229 | 
            +
                        if key1 not in self_attn_name or key2 not in self_attn_name:
         | 
| 230 | 
            +
                            # exclude the variables that are not used in the current layer
         | 
| 231 | 
            +
                            continue
         | 
| 232 | 
            +
                        if key1 in self.masking:
         | 
| 233 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
         | 
| 234 | 
            +
                        if key2 in self.masking:
         | 
| 235 | 
            +
                            self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    # build self_attention mask masking for spatial query
         | 
| 238 | 
            +
                    # spatial query attends to itself
         | 
| 239 | 
            +
                    if 'queries_spatial' in self_attn_name and 'tokens_spatial' in self_attn_name:
         | 
| 240 | 
            +
                        diag_mask = ~(torch.eye(self.extra['spatial_query_number']).repeat_interleave(self.extra['sample_size'],dim=0).repeat_interleave(self.extra['sample_size'],dim=1)).bool()
         | 
| 241 | 
            +
                        self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1]] = diag_mask[None,]
         | 
| 242 | 
            +
                        # spatial query attends to its spatial tokens
         | 
| 243 | 
            +
                        indices = self.extra['spatial_indices'].permute(0,2,1)
         | 
| 244 | 
            +
                        diag_index = torch.arange(self.extra['spatial_query_number'], device=indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
         | 
| 245 | 
            +
                        diag_mask = ~(indices == diag_index)
         | 
| 246 | 
            +
                        self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
         | 
| 247 | 
            +
                        # spatial tokens attend to themselves
         | 
| 248 | 
            +
                        diag_mask = ~(indices == indices.transpose(1,2))
         | 
| 249 | 
            +
                        self_attn_mask[:,self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    if 'memory_indices' in self.extra:
         | 
| 252 | 
            +
                        # spatial query attends to its memory
         | 
| 253 | 
            +
                        memory_indices = self.extra['memory_indices'][None,None,:]
         | 
| 254 | 
            +
                        diag_index = torch.arange(self.extra['spatial_query_number'], device=memory_indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
         | 
| 255 | 
            +
                        diag_mask = ~(diag_index == memory_indices)
         | 
| 256 | 
            +
                        self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
         | 
| 257 | 
            +
                        # memory attends to itself
         | 
| 258 | 
            +
                        diag_mask = ~(memory_indices == memory_indices.transpose(1,2))
         | 
| 259 | 
            +
                        self_attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
         | 
| 262 | 
            +
                    return output, pos_emb, self_attn_mask
         | 
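Note on the final repeat_interleave: the per-sample [bs, Q, Q] boolean mask is expanded head-wise to [bs * num_heads, Q, Q], the 3-D attn_mask layout nn.MultiheadAttention expects. A quick shape check with illustrative numbers:

    import torch

    bs, num_heads, Q = 2, 8, 10
    mask = torch.zeros(bs, Q, Q, dtype=torch.bool)
    assert mask.repeat_interleave(num_heads, dim=0).shape == (bs * num_heads, Q, Q)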
| 263 | 
            +
             | 
| 264 | 
            +
                def update_variables(self, output, mode):
         | 
| 265 | 
            +
                    name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
         | 
| 266 | 
            +
                    for key in name_set:
         | 
| 267 | 
            +
                        self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def update_spatial_results(self, results):
         | 
| 270 | 
            +
                    v_emb = results['pred_smaskembs']
         | 
| 271 | 
            +
                    pred_smasks = results['pred_smasks']
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    s_emb = results['pred_pspatials']
         | 
| 274 | 
            +
                    diag_mask = ~(torch.eye(self.extra['spatial_query_number'], device=s_emb.device).repeat_interleave(self.extra['sample_size'],dim=0)).bool()
         | 
| 275 | 
            +
                    offset = torch.zeros_like(diag_mask, device=s_emb.device).float()
         | 
| 276 | 
            +
                    offset.masked_fill_(diag_mask, float("-inf"))
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    pred_logits = v_emb @ s_emb.transpose(1,2) + offset[None,]
         | 
| 279 | 
            +
                    bs,_,ns=pred_logits.shape
         | 
| 280 | 
            +
                    _,_,h,w=pred_smasks.shape        
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    logits_idx_y = pred_logits.max(dim=1)[1]
         | 
| 283 | 
            +
                    logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)[:,None].repeat(1, logits_idx_y.shape[1])
         | 
| 284 | 
            +
                    logits_idx = torch.stack([logits_idx_x, logits_idx_y]).view(2,-1).tolist()
         | 
| 285 | 
            +
                    pred_masks_pos = pred_smasks[logits_idx].reshape(bs,ns,h,w)
         | 
| 286 | 
            +
                    extra = {"prev_mask": pred_masks_pos}
         | 
| 287 | 
            +
                    return extra
         | 
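In isolation, update_spatial_results scores every candidate mask embedding against each spatial prompt embedding and keeps the best-scoring mask as prev_mask for the next interaction round. A self-contained sketch of the same argmax-gather with hypothetical shapes (the diagonal -inf offset that restricts each prompt to its own query group is omitted; the gather is equivalent to the list-index-and-reshape above):

    import torch

    bs, nq, ns, C, h, w = 2, 5, 3, 4, 8, 8
    v_emb = torch.randn(bs, nq, C)          # stands in for pred_smaskembs
    s_emb = torch.randn(bs, ns, C)          # stands in for pred_pspatials
    pred_smasks = torch.randn(bs, nq, h, w)

    pred_logits = v_emb @ s_emb.transpose(1, 2)           # [bs, nq, ns]
    best_q = pred_logits.max(dim=1)[1]                    # best query per spatial prompt
    batch_idx = torch.arange(bs)[:, None].repeat(1, ns)   # [bs, ns]
    prev_mask = pred_smasks[batch_idx, best_q]            # [bs, ns, h, w]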
| 288 | 
            +
             | 
| 289 | 
            +
                def organize_output(self, ):
         | 
| 290 | 
            +
                    outputs = {}
         | 
| 291 | 
            +
                    outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
         | 
| 292 | 
            +
                    for key, values in self.output.items():
         | 
| 293 | 
            +
                        for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
         | 
| 294 | 
            +
                            if idx_name not in self.query_index:
         | 
| 295 | 
            +
                                continue
         | 
| 296 | 
            +
                            outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 297 | 
            +
                            for idx, aux_values in enumerate(self.output[key][:-1]):
         | 
| 298 | 
            +
                                outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
         | 
| 299 | 
            +
                    if self.task == 'spatial' or self.task == 'refimg':
         | 
| 300 | 
            +
                        outputs = self.update_spatial_results(outputs)
         | 
| 301 | 
            +
                    # outputs = self.update_spatial_results(outputs)
         | 
| 302 | 
            +
                    return outputs
         | 
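For reference, the ATTENTION_ARCH block consumed by this class needs the keys read in __init__ (VARIABLE, SELF_ATTENTION, CROSS_ATTENTION, MASKING, DUPLICATION), while SPATIAL_MEMORIES is read by the decoder and NUM_LAYERS is filled in by the decoder itself. The sketch below only illustrates the expected nesting; the variable names and values are assumptions, not the shipped config:

    attn_arch = {
        'VARIABLE':        {'queries': ['object', 'grounding', 'spatial']},
        'CROSS_ATTENTION': {'queries': {'object': True, 'grounding': False, 'spatial': True}},
        'SELF_ATTENTION':  {'queries': {'object': ['queries_object'],
                                        'grounding': [],
                                        'spatial': ['queries_spatial', 'tokens_spatial']}},
        'MASKING':         ['tokens_grounding', 'tokens_spatial'],
        'DUPLICATION':     {'queries': {'grounding': 'queries_object'}},
        'SPATIAL_MEMORIES': 32,   # read by the decoder, not by AttentionDataStruct
    }
    # attn_arch['NUM_LAYERS'] is assigned by the decoder before it constructs
    # AttentionDataStruct (see seem_demo.py below).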
    	
        modeling/interface/seem_demo.py
    ADDED
    
    | @@ -0,0 +1,397 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # SEEM -- Segment Everything Everywhere All At Once
         | 
| 3 | 
            +
            # Licensed under The Apache License 2.0 [see LICENSE for details]
         | 
| 4 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com)
         | 
| 5 | 
            +
            # --------------------------------------------------------
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            from typing import Optional
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torch import nn, Tensor
         | 
| 12 | 
            +
            from torch.nn import functional as F
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from timm.models.layers import trunc_normal_
         | 
| 15 | 
            +
            from detectron2.layers import Conv2d
         | 
| 16 | 
            +
            import fvcore.nn.weight_init as weight_init
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .build import register_decoder
         | 
| 19 | 
            +
            from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
         | 
| 20 | 
            +
            from .prototype.attention_data_struct_seemdemo import AttentionDataStruct
         | 
| 21 | 
            +
            from ..utils import rand_sample_plain as rand_sample
         | 
| 22 | 
            +
            from ..utils import prepare_features, configurable
         | 
| 23 | 
            +
            from ..modules import PositionEmbeddingSine
         | 
| 24 | 
            +
            from ..modules.point_features import point_sample
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            class SEEMDecoder(nn.Module):
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                @configurable
         | 
| 30 | 
            +
                def __init__(
         | 
| 31 | 
            +
                    self,
         | 
| 32 | 
            +
                    lang_encoder: nn.Module,
         | 
| 33 | 
            +
                    in_channels,
         | 
| 34 | 
            +
                    mask_classification=True,
         | 
| 35 | 
            +
                    *,
         | 
| 36 | 
            +
                    hidden_dim: int,
         | 
| 37 | 
            +
                    dim_proj: int,
         | 
| 38 | 
            +
                    num_queries: int,
         | 
| 39 | 
            +
                    contxt_len: int,
         | 
| 40 | 
            +
                    nheads: int,
         | 
| 41 | 
            +
                    dim_feedforward: int,
         | 
| 42 | 
            +
                    dec_layers: int,
         | 
| 43 | 
            +
                    pre_norm: bool,
         | 
| 44 | 
            +
                    mask_dim: int,
         | 
| 45 | 
            +
                    task_switch: dict,
         | 
| 46 | 
            +
                    enforce_input_project: bool,
         | 
| 47 | 
            +
                    max_spatial_len: int,
         | 
| 48 | 
            +
                    attn_arch: dict,
         | 
| 49 | 
            +
                ):
         | 
| 50 | 
            +
                    """
         | 
| 51 | 
            +
                    NOTE: this interface is experimental.
         | 
| 52 | 
            +
                    Args:
         | 
| 53 | 
            +
                        in_channels: channels of the input features
         | 
| 54 | 
            +
                        mask_classification: whether to add mask classifier or not
         | 
| 55 | 
            +
                        num_classes: number of classes
         | 
| 56 | 
            +
                        hidden_dim: Transformer feature dimension
         | 
| 57 | 
            +
                        num_queries: number of queries
         | 
| 58 | 
            +
                        nheads: number of heads
         | 
| 59 | 
            +
                        dim_feedforward: feature dimension in feedforward network
         | 
| 60 | 
            +
                        enc_layers: number of Transformer encoder layers
         | 
| 61 | 
            +
                        dec_layers: number of Transformer decoder layers
         | 
| 62 | 
            +
                        pre_norm: whether to use pre-LayerNorm or not
         | 
| 63 | 
            +
                        mask_dim: mask feature dimension
         | 
| 64 | 
            +
                        enforce_input_project: add input project 1x1 conv even if input
         | 
| 65 | 
            +
                        channels and hidden dim are identical
         | 
| 66 | 
            +
                    """
         | 
| 67 | 
            +
                    super().__init__()
         | 
| 68 | 
            +
                    assert mask_classification, "Only support mask classification model"
         | 
| 69 | 
            +
                    self.mask_classification = mask_classification
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    # positional encoding
         | 
| 72 | 
            +
                    N_steps = hidden_dim // 2
         | 
| 73 | 
            +
                    self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
                    # define Transformer decoder here
         | 
| 76 | 
            +
                    self.num_heads = nheads
         | 
| 77 | 
            +
                    self.num_layers = dec_layers
         | 
| 78 | 
            +
                    self.contxt_len = contxt_len
         | 
| 79 | 
            +
                    self.transformer_self_attention_layers = nn.ModuleList()
         | 
| 80 | 
            +
                    self.transformer_cross_attention_layers = nn.ModuleList()
         | 
| 81 | 
            +
                    self.transformer_ffn_layers = nn.ModuleList()
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    for _ in range(self.num_layers):
         | 
| 84 | 
            +
                        self.transformer_self_attention_layers.append(
         | 
| 85 | 
            +
                            SelfAttentionLayer(
         | 
| 86 | 
            +
                                d_model=hidden_dim,
         | 
| 87 | 
            +
                                nhead=nheads,
         | 
| 88 | 
            +
                                dropout=0.0,
         | 
| 89 | 
            +
                                normalize_before=pre_norm,
         | 
| 90 | 
            +
                            )
         | 
| 91 | 
            +
                        )
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        self.transformer_cross_attention_layers.append(
         | 
| 94 | 
            +
                            CrossAttentionLayer(
         | 
| 95 | 
            +
                                d_model=hidden_dim,
         | 
| 96 | 
            +
                                nhead=nheads,
         | 
| 97 | 
            +
                                dropout=0.0,
         | 
| 98 | 
            +
                                normalize_before=pre_norm,
         | 
| 99 | 
            +
                            )
         | 
| 100 | 
            +
                        )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                        self.transformer_ffn_layers.append(
         | 
| 103 | 
            +
                            FFNLayer(
         | 
| 104 | 
            +
                                d_model=hidden_dim,
         | 
| 105 | 
            +
                                dim_feedforward=dim_feedforward,
         | 
| 106 | 
            +
                                dropout=0.0,
         | 
| 107 | 
            +
                                normalize_before=pre_norm,
         | 
| 108 | 
            +
                            )
         | 
| 109 | 
            +
                        )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.decoder_norm = nn.LayerNorm(hidden_dim)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    self.num_queries = num_queries
         | 
| 114 | 
            +
                    # learnable query features
         | 
| 115 | 
            +
                    self.query_feat = nn.Embedding(num_queries, hidden_dim)
         | 
| 116 | 
            +
                    # learnable query p.e.
         | 
| 117 | 
            +
                    self.query_embed = nn.Embedding(num_queries, hidden_dim)
         | 
| 118 | 
            +
                    # learnable positive negative indicator
         | 
| 119 | 
            +
                    self.pn_indicator = nn.Embedding(2, hidden_dim)
         | 
| 120 | 
            +
                    
         | 
| 121 | 
            +
                    # level embedding (we always use 3 scales)
         | 
| 122 | 
            +
                    self.num_feature_levels = 3
         | 
| 123 | 
            +
                    self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
         | 
| 124 | 
            +
                    self.input_proj = nn.ModuleList()
         | 
| 125 | 
            +
                    
         | 
| 126 | 
            +
                    for _ in range(self.num_feature_levels):
         | 
| 127 | 
            +
                        if in_channels != hidden_dim or enforce_input_project:
         | 
| 128 | 
            +
                            self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
         | 
| 129 | 
            +
                            weight_init.c2_xavier_fill(self.input_proj[-1])
         | 
| 130 | 
            +
                        else:
         | 
| 131 | 
            +
                            self.input_proj.append(nn.Sequential())
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    self.task_switch = task_switch
         | 
| 134 | 
            +
                    self.query_index = {}
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    # output FFNs
         | 
| 137 | 
            +
                    self.lang_encoder = lang_encoder
         | 
| 138 | 
            +
                    if self.task_switch['mask']:
         | 
| 139 | 
            +
                        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
         | 
| 142 | 
            +
                    trunc_normal_(self.class_embed, std=.02)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if task_switch['bbox']:
         | 
| 145 | 
            +
                        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    if task_switch['spatial']:
         | 
| 148 | 
            +
                        # spatial query
         | 
| 149 | 
            +
                        self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
         | 
| 150 | 
            +
                        trunc_normal_(self.mask_sptial_embed[0], std=.02)
         | 
| 151 | 
            +
                        trunc_normal_(self.mask_sptial_embed[1], std=.02)
         | 
| 152 | 
            +
                        trunc_normal_(self.mask_sptial_embed[2], std=.02)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                        self.max_spatial_len = max_spatial_len
         | 
| 155 | 
            +
                        # spatial memory
         | 
| 156 | 
            +
                        num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
         | 
| 157 | 
            +
                        self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
         | 
| 158 | 
            +
                        self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    # build AttentionDataStruct
         | 
| 161 | 
            +
                    attn_arch['NUM_LAYERS'] = self.num_layers
         | 
| 162 | 
            +
                    self.attention_data = AttentionDataStruct(attn_arch, task_switch)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                @classmethod
         | 
| 165 | 
            +
                def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
         | 
| 166 | 
            +
                    ret = {}
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    ret["lang_encoder"] = lang_encoder
         | 
| 169 | 
            +
                    ret["in_channels"] = in_channels
         | 
| 170 | 
            +
                    ret["mask_classification"] = mask_classification
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    enc_cfg = cfg['MODEL']['ENCODER']
         | 
| 173 | 
            +
                    dec_cfg = cfg['MODEL']['DECODER']
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
         | 
| 176 | 
            +
                    ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
         | 
| 177 | 
            +
                    ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
         | 
| 178 | 
            +
                    ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    # Transformer parameters:
         | 
| 181 | 
            +
                    ret["nheads"] = dec_cfg['NHEADS']
         | 
| 182 | 
            +
                    ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    # NOTE: because we add learnable query features which requires supervision,
         | 
| 185 | 
            +
                    # we subtract 1 from the decoder layers to be consistent with our loss
         | 
| 186 | 
            +
                    # implementation: that is, number of auxiliary losses is always
         | 
| 187 | 
            +
                    # equal to number of decoder layers. With learnable query features, the number of
         | 
| 188 | 
            +
                    # auxiliary losses equals number of decoders plus 1.
         | 
| 189 | 
            +
                    assert dec_cfg['DEC_LAYERS'] >= 1
         | 
| 190 | 
            +
                    ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
         | 
| 191 | 
            +
                    ret["pre_norm"] = dec_cfg['PRE_NORM']
         | 
| 192 | 
            +
                    ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
         | 
| 193 | 
            +
                    ret["mask_dim"] = enc_cfg['MASK_DIM']
         | 
| 194 | 
            +
                    ret["task_switch"] = extra['task_switch']
         | 
| 195 | 
            +
                    ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    # attn data struct
         | 
| 198 | 
            +
                    ret["attn_arch"] = cfg['ATTENTION_ARCH']
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    return ret
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
         | 
| 203 | 
            +
                    # x is a list of multi-scale feature
         | 
| 204 | 
            +
                    assert len(x) == self.num_feature_levels; del mask
         | 
| 205 | 
            +
                    spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg'
         | 
| 206 | 
            +
                    grounding_extra_flag = 'grounding_tokens' in extra.keys()
         | 
| 207 | 
            +
                    visual_extra_flag = 'visual_query_pos' in extra.keys()
         | 
| 208 | 
            +
                    audio_extra_flag = 'audio_tokens' in extra.keys()
         | 
| 209 | 
            +
                    spatial_memory_flag = 'prev_mask' in extra.keys()
         | 
| 210 | 
            +
                    flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag, "visual": visual_extra_flag, "audio": audio_extra_flag}
         | 
| 211 | 
            +
                    self.attention_data.reset(flags, task, extra)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
         | 
| 214 | 
            +
                    _, bs, _ = src[0].shape
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    # QxNxC
         | 
| 217 | 
            +
                    query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 218 | 
            +
                    output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 219 | 
            +
                    self.attention_data.set('queries_object', 'queries', output, query_embed)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    if self.task_switch['spatial'] and spatial_extra_flag:
         | 
| 222 | 
            +
                        # get divisor
         | 
| 223 | 
            +
                        _,h,w = extra['spatial_query_pos_mask'][0].shape
         | 
| 224 | 
            +
                        divisor = torch.tensor([h,w], device=output.device)[None,]
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                        # Get mean pos spatial query
         | 
| 227 | 
            +
                        non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
         | 
| 228 | 
            +
                        non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
         | 
| 229 | 
            +
                        non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
         | 
| 230 | 
            +
                        spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
         | 
| 231 | 
            +
                        spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                        # Get mean neg spatial query
         | 
| 234 | 
            +
                        non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
         | 
| 235 | 
            +
                        non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
         | 
| 236 | 
            +
                        non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
         | 
| 237 | 
            +
                        spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
         | 
| 238 | 
            +
                        spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                        # merge positive and negative sample points for self attention
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                        # Get layerwise spatial query
         | 
| 243 | 
            +
                        src_spatial_queries = []
         | 
| 244 | 
            +
                        src_spatial_maskings = []
         | 
| 245 | 
            +
                        for i in range(len(src)):
         | 
| 246 | 
            +
                            hw,_,dc = src[i].shape
         | 
| 247 | 
            +
                            src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
         | 
| 248 | 
            +
                            src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                            non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
         | 
| 251 | 
            +
                            non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
         | 
| 252 | 
            +
                            non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                            pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
         | 
| 255 | 
            +
                            pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                            non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
         | 
| 258 | 
            +
                            non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
         | 
| 259 | 
            +
                            non_zero_query_point[non_zero_query_mask] = 0
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                            spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
         | 
| 262 | 
            +
                            spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
         | 
| 263 | 
            +
                            spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                            src_spatial_queries += [spatial_tokens]
         | 
| 266 | 
            +
                            src_spatial_maskings += [non_zero_query_mask]
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                        if 'refimg' in task:
         | 
| 269 | 
            +
                            output_refimg = {}
         | 
| 270 | 
            +
                            output_refimg['visual_query_pos'] = spatial_query_pos
         | 
| 271 | 
            +
                            output_refimg['visual_query_neg'] = spatial_query_neg
         | 
| 272 | 
            +
                            output_refimg['src_visual_queries'] = src_spatial_queries
         | 
| 273 | 
            +
                            output_refimg['src_visual_maskings'] = src_spatial_maskings
         | 
| 274 | 
            +
                            return output_refimg
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                        if task != 'demo':
         | 
| 277 | 
            +
                            # Get object query for spatial index
         | 
| 278 | 
            +
                            self.attention_data.set('queries_spatial', 'queries')
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    if self.task_switch['visual'] and visual_extra_flag:
         | 
| 281 | 
            +
                        visual_query_pos = extra['visual_query_pos']
         | 
| 282 | 
            +
                        visual_query_neg = extra['visual_query_neg']
         | 
| 283 | 
            +
                        src_visual_queries = extra['src_visual_queries']
         | 
| 284 | 
            +
                        src_visual_maskings = extra['src_visual_maskings']
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    if self.task_switch['grounding'] and grounding_extra_flag:
         | 
| 287 | 
            +
                        # Get grounding tokens
         | 
| 288 | 
            +
                        grounding_tokens = extra['grounding_tokens']
         | 
| 289 | 
            +
                        _grounding_tokens = grounding_tokens.detach().clone()
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                        self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
         | 
| 292 | 
            +
                        self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    if self.task_switch['audio'] and audio_extra_flag:
         | 
| 295 | 
            +
                        # Get audio tokens (routed through the grounding-token pathway)
         | 
| 296 | 
            +
                        grounding_tokens = extra['audio_tokens']
         | 
| 297 | 
            +
                        _grounding_tokens = grounding_tokens.detach().clone()
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                        self.attention_data.set('tokens_audio', 'tokens', grounding_tokens, _grounding_tokens)
         | 
| 300 | 
            +
                        self.attention_data.set_maskings('tokens_audio', extra['audio_nonzero_mask'])
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    output, query_embed = self.attention_data.cross_attn_variables()
         | 
| 303 | 
            +
                    # prediction heads on learnable query features
         | 
| 304 | 
            +
                    results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
         | 
| 305 | 
            +
                    results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
         | 
| 306 | 
            +
                    results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
         | 
| 307 | 
            +
                    results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
         | 
| 308 | 
            +
                    results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
         | 
| 309 | 
            +
                    self.attention_data.set_results(results)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    for i in range(self.num_layers):
         | 
| 312 | 
            +
                        level_index = i % self.num_feature_levels
         | 
| 313 | 
            +
                        # CROSS ATTENTION
         | 
| 314 | 
            +
                        output, avg_attn = self.transformer_cross_attention_layers[i](
         | 
| 315 | 
            +
                            output, src[level_index],
         | 
| 316 | 
            +
                            memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
         | 
| 317 | 
            +
                            memory_key_padding_mask=None,  # here we do not apply masking on padded region
         | 
| 318 | 
            +
                            pos=pos[level_index], query_pos=query_embed
         | 
| 319 | 
            +
                        )
         | 
| 320 | 
            +
                        self.attention_data.update_variables(output, 'cross_attn')
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                        # SELF ATTENTION
         | 
| 323 | 
            +
                        self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
         | 
| 324 | 
            +
                        if self.task_switch['spatial'] and spatial_extra_flag:
         | 
| 325 | 
            +
                            # get spatial tokens
         | 
| 326 | 
            +
                            spatial_tokens = src_spatial_queries[level_index]
         | 
| 327 | 
            +
                            _spatial_tokens = spatial_tokens.detach().clone()
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                            self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
         | 
| 330 | 
            +
                            self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                        if self.task_switch['visual'] and visual_extra_flag:
         | 
| 333 | 
            +
                            # get visual tokens
         | 
| 334 | 
            +
                            visual_tokens = src_visual_queries[level_index]
         | 
| 335 | 
            +
                            _visual_tokens = visual_tokens.detach().clone()
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                            self.attention_data.set('tokens_visual', 'tokens', visual_tokens, _visual_tokens)
         | 
| 338 | 
            +
                            self.attention_data.set_maskings('tokens_visual', src_visual_maskings[level_index])
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                        output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
         | 
| 341 | 
            +
                        output = self.transformer_self_attention_layers[i](
         | 
| 342 | 
            +
                            output, tgt_mask=self_attn_mask,
         | 
| 343 | 
            +
                            tgt_key_padding_mask=None,
         | 
| 344 | 
            +
                            query_pos=query_embed)
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                        # FFN
         | 
| 347 | 
            +
                        output = self.transformer_ffn_layers[i](
         | 
| 348 | 
            +
                            output
         | 
| 349 | 
            +
                        )
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                        self.attention_data.update_variables(output, 'self_attn')
         | 
| 352 | 
            +
                        output, query_embed = self.attention_data.cross_attn_variables()
         | 
| 353 | 
            +
                        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
         | 
| 354 | 
            +
                        results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
         | 
| 355 | 
            +
                        results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
         | 
| 356 | 
            +
                        results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
         | 
| 357 | 
            +
                        results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
         | 
| 358 | 
            +
                        self.attention_data.set_results(results)
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    return self.attention_data.organize_output()
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
         | 
| 363 | 
            +
                    decoder_output = self.decoder_norm(output)
         | 
| 364 | 
            +
                    decoder_output = decoder_output.transpose(0, 1)
         | 
| 365 | 
            +
                    class_embed = decoder_output @ self.class_embed
         | 
| 366 | 
            +
                    outputs_class = self.lang_encoder.compute_similarity(class_embed)
         | 
| 367 | 
            +
                    mask_embed = self.mask_embed(decoder_output)
         | 
| 368 | 
            +
                    outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
         | 
| 369 | 
            +
                    
         | 
| 370 | 
            +
                    outputs_bbox = [None for i in range(len(outputs_mask))]
         | 
| 371 | 
            +
                    if self.task_switch['bbox']:
         | 
| 372 | 
            +
                        outputs_bbox = self.bbox_embed(decoder_output)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    # NOTE: prediction is of higher-resolution
         | 
| 375 | 
            +
                    # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
         | 
| 376 | 
            +
                    attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    # must use bool type
         | 
| 379 | 
            +
                    # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
         | 
| 380 | 
            +
                    attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
         | 
| 381 | 
            +
                    attn_mask = attn_mask.detach()
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    outputs_caption = class_embed
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    results = {
         | 
| 386 | 
            +
                        "attn_mask": attn_mask,
         | 
| 387 | 
            +
                        "predictions_class": outputs_class,
         | 
| 388 | 
            +
                        "predictions_mask": outputs_mask,
         | 
| 389 | 
            +
                        "predictions_bbox": outputs_bbox,
         | 
| 390 | 
            +
                        "predictions_caption": outputs_caption,
         | 
| 391 | 
            +
                        "predictions_maskemb": mask_embed,
         | 
| 392 | 
            +
                    }
         | 
| 393 | 
            +
                    return results
         | 
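forward_prediction_heads couples the mask head to the next layer's cross-attention: each query embedding is contracted against the per-pixel mask features, and the resulting logits, once downsampled and thresholded at a sigmoid of 0.5, feed the memory mask of the following cross-attention layer (via attention_data.cross_attn_mask). A standalone sketch of those two steps with made-up shapes, not the module's real dimensions:

    import torch
    import torch.nn.functional as F

    bs, num_queries, C, num_heads = 2, 101, 512, 8
    H, W = 128, 128            # resolution of mask_features
    target_hw = (16, 16)       # resolution of the feature level attended next

    mask_embed = torch.randn(bs, num_queries, C)      # output of the mask MLP
    mask_features = torch.randn(bs, C, H, W)          # per-pixel embeddings

    # One logit map per query: dot product of its embedding with every pixel feature.
    outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

    # [B, Q, h, w] -> [B*heads, Q, h*w]; True marks positions a query may NOT attend to.
    attn_mask = F.interpolate(outputs_mask, size=target_hw, mode="bilinear", align_corners=False)
    attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1)
                 .repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()

    print(outputs_mask.shape, attn_mask.shape)  # [2, 101, 128, 128] and [16, 101, 256]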
| 394 | 
            +
             | 
| 395 | 
            +
            @register_decoder
         | 
| 396 | 
            +
            def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
         | 
| 397 | 
            +
                return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
         | 
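The file closes by registering its factory through the register_decoder decorator imported from .build. As an illustration only (this is not the body of modeling/interface/build.py), such a name-keyed registry is typically nothing more than a module-level dict plus a lookup helper:

    # Hypothetical sketch of a decoder registry; only the decorator name mirrors the import above.
    _DECODER_REGISTRY = {}

    def register_decoder(fn):
        # Remember the factory under its function name so a config string can select it later.
        _DECODER_REGISTRY[fn.__name__] = fn
        return fn

    def build_decoder(name, *args, **kwargs):
        # e.g. build_decoder('get_seem_interface', cfg, in_channels, lang_encoder, True, extra)
        return _DECODER_REGISTRY[name](*args, **kwargs)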
    	
        modeling/interface/seem_v0.py
    ADDED
    
    | @@ -0,0 +1,392 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # SEEM -- Segment Everything Everywhere All at Once
         | 
| 3 | 
            +
            # Licensed under The Apache License 2.0 [see LICENSE for details]
         | 
| 4 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 5 | 
            +
            # --------------------------------------------------------
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            from typing import Optional
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torch import nn, Tensor
         | 
| 12 | 
            +
            from torch.nn import functional as F
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from timm.models.layers import trunc_normal_
         | 
| 15 | 
            +
            from detectron2.layers import Conv2d
         | 
| 16 | 
            +
            import fvcore.nn.weight_init as weight_init
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .build import register_decoder
         | 
| 19 | 
            +
            from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
         | 
| 20 | 
            +
            from .prototype.attention_data_struct_seemv0 import AttentionDataStruct
         | 
| 21 | 
            +
            from ..utils import rand_sample_plain as rand_sample
         | 
| 22 | 
            +
            from ..utils import prepare_features, configurable
         | 
| 23 | 
            +
            from ..modules import PositionEmbeddingSine
         | 
| 24 | 
            +
            from ..modules.point_features import point_sample
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            class SEEMDecoder(nn.Module):
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                @configurable
         | 
| 30 | 
            +
                def __init__(
         | 
| 31 | 
            +
                    self,
         | 
| 32 | 
            +
                    lang_encoder: nn.Module,
         | 
| 33 | 
            +
                    in_channels,
         | 
| 34 | 
            +
                    mask_classification=True,
         | 
| 35 | 
            +
                    *,
         | 
| 36 | 
            +
                    hidden_dim: int,
         | 
| 37 | 
            +
                    dim_proj: int,
         | 
| 38 | 
            +
                    num_queries: int,
         | 
| 39 | 
            +
                    contxt_len: int,
         | 
| 40 | 
            +
                    nheads: int,
         | 
| 41 | 
            +
                    dim_feedforward: int,
         | 
| 42 | 
            +
                    dec_layers: int,
         | 
| 43 | 
            +
                    pre_norm: bool,
         | 
| 44 | 
            +
                    mask_dim: int,
         | 
| 45 | 
            +
                    task_switch: dict,
         | 
| 46 | 
            +
                    enforce_input_project: bool,
         | 
| 47 | 
            +
                    max_spatial_len: int,
         | 
| 48 | 
            +
                    attn_arch: dict,
         | 
| 49 | 
            +
                ):
         | 
| 50 | 
            +
                    """
         | 
| 51 | 
            +
                    NOTE: this interface is experimental.
         | 
| 52 | 
            +
                    Args:
         | 
| 53 | 
            +
                        in_channels: channels of the input features
         | 
| 54 | 
            +
                        mask_classification: whether to add mask classifier or not
         | 
| 55 | 
            +
                        num_classes: number of classes
         | 
| 56 | 
            +
                        hidden_dim: Transformer feature dimension
         | 
| 57 | 
            +
                        num_queries: number of queries
         | 
| 58 | 
            +
                        nheads: number of heads
         | 
| 59 | 
            +
                        dim_feedforward: feature dimension in feedforward network
         | 
| 60 | 
            +
                        enc_layers: number of Transformer encoder layers
         | 
| 61 | 
            +
                        dec_layers: number of Transformer decoder layers
         | 
| 62 | 
            +
                        pre_norm: whether to use pre-LayerNorm or not
         | 
| 63 | 
            +
                        mask_dim: mask feature dimension
         | 
| 64 | 
            +
                        enforce_input_project: add input project 1x1 conv even if input
         | 
| 65 | 
            +
                            channels and hidden dim are identical
         | 
| 66 | 
            +
                    """
         | 
| 67 | 
            +
                    super().__init__()
         | 
| 68 | 
            +
                    assert mask_classification, "Only support mask classification model"
         | 
| 69 | 
            +
                    self.mask_classification = mask_classification
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    # positional encoding
         | 
| 72 | 
            +
                    N_steps = hidden_dim // 2
         | 
| 73 | 
            +
                    self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
                    # define Transformer decoder here
         | 
| 76 | 
            +
                    self.num_heads = nheads
         | 
| 77 | 
            +
                    self.num_layers = dec_layers
         | 
| 78 | 
            +
                    self.contxt_len = contxt_len
         | 
| 79 | 
            +
                    self.transformer_self_attention_layers = nn.ModuleList()
         | 
| 80 | 
            +
                    self.transformer_cross_attention_layers = nn.ModuleList()
         | 
| 81 | 
            +
                    self.transformer_ffn_layers = nn.ModuleList()
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    for _ in range(self.num_layers):
         | 
| 84 | 
            +
                        self.transformer_self_attention_layers.append(
         | 
| 85 | 
            +
                            SelfAttentionLayer(
         | 
| 86 | 
            +
                                d_model=hidden_dim,
         | 
| 87 | 
            +
                                nhead=nheads,
         | 
| 88 | 
            +
                                dropout=0.0,
         | 
| 89 | 
            +
                                normalize_before=pre_norm,
         | 
| 90 | 
            +
                            )
         | 
| 91 | 
            +
                        )
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        self.transformer_cross_attention_layers.append(
         | 
| 94 | 
            +
                            CrossAttentionLayer(
         | 
| 95 | 
            +
                                d_model=hidden_dim,
         | 
| 96 | 
            +
                                nhead=nheads,
         | 
| 97 | 
            +
                                dropout=0.0,
         | 
| 98 | 
            +
                                normalize_before=pre_norm,
         | 
| 99 | 
            +
                            )
         | 
| 100 | 
            +
                        )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                        self.transformer_ffn_layers.append(
         | 
| 103 | 
            +
                            FFNLayer(
         | 
| 104 | 
            +
                                d_model=hidden_dim,
         | 
| 105 | 
            +
                                dim_feedforward=dim_feedforward,
         | 
| 106 | 
            +
                                dropout=0.0,
         | 
| 107 | 
            +
                                normalize_before=pre_norm,
         | 
| 108 | 
            +
                            )
         | 
| 109 | 
            +
                        )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.decoder_norm = nn.LayerNorm(hidden_dim)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    self.num_queries = num_queries
         | 
| 114 | 
            +
                    # learnable query features
         | 
| 115 | 
            +
                    self.query_feat = nn.Embedding(num_queries, hidden_dim)
         | 
| 116 | 
            +
                    # learnable query p.e.
         | 
| 117 | 
            +
                    self.query_embed = nn.Embedding(num_queries, hidden_dim)
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    # level embedding (we always use 3 scales)
         | 
| 120 | 
            +
                    self.num_feature_levels = 3
         | 
| 121 | 
            +
                    self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
         | 
| 122 | 
            +
                    self.input_proj = nn.ModuleList()
         | 
| 123 | 
            +
                    
         | 
| 124 | 
            +
                    for _ in range(self.num_feature_levels):
         | 
| 125 | 
            +
                        if in_channels != hidden_dim or enforce_input_project:
         | 
| 126 | 
            +
                            self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
         | 
| 127 | 
            +
                            weight_init.c2_xavier_fill(self.input_proj[-1])
         | 
| 128 | 
            +
                        else:
         | 
| 129 | 
            +
                            self.input_proj.append(nn.Sequential())
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    self.task_switch = task_switch
         | 
| 132 | 
            +
                    self.query_index = {}
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # output FFNs
         | 
| 135 | 
            +
                    self.lang_encoder = lang_encoder
         | 
| 136 | 
            +
                    self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
         | 
| 137 | 
            +
                    self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
         | 
| 138 | 
            +
                    trunc_normal_(self.class_embed, std=.02)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    if task_switch['bbox']:
         | 
| 141 | 
            +
                        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    if task_switch['spatial']:
         | 
| 144 | 
            +
                        # spatial query
         | 
| 145 | 
            +
                        self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
         | 
| 146 | 
            +
                        trunc_normal_(self.mask_sptial_embed[0], std=.02)
         | 
| 147 | 
            +
                        trunc_normal_(self.mask_sptial_embed[1], std=.02)
         | 
| 148 | 
            +
                        trunc_normal_(self.mask_sptial_embed[2], std=.02)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                        self.max_spatial_len = max_spatial_len
         | 
| 151 | 
            +
                        # spatial memory
         | 
| 152 | 
            +
                        num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
         | 
| 153 | 
            +
                        self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
         | 
| 154 | 
            +
                        self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                        # learnable positive negative indicator
         | 
| 157 | 
            +
                        self.pn_indicator = nn.Embedding(2, hidden_dim)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    # build AttentionDataStruct
         | 
| 160 | 
            +
                    attn_arch['NUM_LAYERS'] = self.num_layers
         | 
| 161 | 
            +
                    self.attention_data = AttentionDataStruct(attn_arch, task_switch)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                @classmethod
         | 
| 164 | 
            +
                def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
         | 
| 165 | 
            +
                    ret = {}
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    ret["lang_encoder"] = lang_encoder
         | 
| 168 | 
            +
                    ret["in_channels"] = in_channels
         | 
| 169 | 
            +
                    ret["mask_classification"] = mask_classification
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    enc_cfg = cfg['MODEL']['ENCODER']
         | 
| 172 | 
            +
                    dec_cfg = cfg['MODEL']['DECODER']
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
         | 
| 175 | 
            +
                    ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
         | 
| 176 | 
            +
                    ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
         | 
| 177 | 
            +
                    ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # Transformer parameters:
         | 
| 180 | 
            +
                    ret["nheads"] = dec_cfg['NHEADS']
         | 
| 181 | 
            +
                    ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    # NOTE: because we add learnable query features which require supervision,
         | 
| 184 | 
            +
                    # we subtract 1 from the number of decoder layers to be consistent with our loss
         | 
| 185 | 
            +
                    # implementation: that is, number of auxiliary losses is always
         | 
| 186 | 
            +
                    # equal to number of decoder layers. With learnable query features, the number of
         | 
| 187 | 
            +
                    # auxiliary losses equals number of decoders plus 1.
         | 
| 188 | 
            +
                    assert dec_cfg['DEC_LAYERS'] >= 1
         | 
| 189 | 
            +
                    ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
         | 
| 190 | 
            +
                    ret["pre_norm"] = dec_cfg['PRE_NORM']
         | 
| 191 | 
            +
                    ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
         | 
| 192 | 
            +
                    ret["mask_dim"] = enc_cfg['MASK_DIM']
         | 
| 193 | 
            +
                    ret["task_switch"] = extra['task_switch']
         | 
| 194 | 
            +
                    ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    # attn data struct
         | 
| 197 | 
            +
                    ret["attn_arch"] = cfg['ATTENTION_ARCH']
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    return ret
         | 
| 200 | 
            +
             | 
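from_config only touches a handful of keys, so the shape of the expected config is easy to read off. The dict below is purely illustrative (placeholder values, not the configuration shipped with this Space); the key paths are the ones dereferenced above:

    # Hypothetical config skeleton for SEEMDecoder.from_config; values are placeholders.
    cfg = {
        'MODEL': {
            'DIM_PROJ': 512,
            'TEXT': {'CONTEXT_LENGTH': 77},
            'ENCODER': {'MASK_DIM': 512},
            'DECODER': {
                'HIDDEN_DIM': 512,
                'NUM_OBJECT_QUERIES': 101,
                'NHEADS': 8,
                'DIM_FEEDFORWARD': 2048,
                'DEC_LAYERS': 10,                     # from_config passes DEC_LAYERS - 1 to __init__
                'PRE_NORM': False,
                'ENFORCE_INPUT_PROJ': False,
                'MAX_SPATIAL_LEN': [512, 512, 512],   # one sampling budget per feature level
            },
        },
        'ATTENTION_ARCH': {'SPATIAL_MEMORIES': 32},   # plus whatever AttentionDataStruct reads
    }
    dec_cfg = cfg['MODEL']['DECODER']                 # mirrors the access pattern in from_config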
| 201 | 
            +
                def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
         | 
| 202 | 
            +
                    # x is a list of multi-scale feature
         | 
| 203 | 
            +
                    assert len(x) == self.num_feature_levels; del mask
         | 
| 204 | 
            +
                    spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
         | 
| 205 | 
            +
                    grounding_extra_flag = 'grounding_tokens' in extra.keys()
         | 
| 206 | 
            +
                    spatial_memory_flag = 'prev_mask' in extra.keys()
         | 
| 207 | 
            +
                    flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
         | 
| 208 | 
            +
                    self.attention_data.reset(flags, task, extra)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
         | 
| 211 | 
            +
                    _, bs, _ = src[0].shape
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    # QxNxC
         | 
| 214 | 
            +
                    query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 215 | 
            +
                    output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 216 | 
            +
                    self.attention_data.set('queries_object', 'queries', output, query_embed)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    if self.task_switch['spatial'] and spatial_extra_flag:
         | 
| 219 | 
            +
                        if 'refimg_tokens' not in extra:
         | 
| 220 | 
            +
                            # get divisor
         | 
| 221 | 
            +
                            _,h,w = extra['spatial_query_pos_mask'][0].shape
         | 
| 222 | 
            +
                            divisor = torch.tensor([h,w], device=output.device)[None,]
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                            # Get mean pos spatial query
         | 
| 225 | 
            +
                            non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
         | 
| 226 | 
            +
                            non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
         | 
| 227 | 
            +
                            non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
         | 
| 228 | 
            +
                            spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
         | 
| 229 | 
            +
                            spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                            # Get mean neg spatial query
         | 
| 232 | 
            +
                            non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
         | 
| 233 | 
            +
                            non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
         | 
| 234 | 
            +
                            non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
         | 
| 235 | 
            +
                            spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
         | 
| 236 | 
            +
                            spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
         | 
| 237 | 
            +
             | 
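Both the positive and the negative branch above reduce the sampled point features to a single query by a masked mean, with nan_to_num() catching the case where a sample has no valid points at all. The idiom in isolation, with toy tensors (the real code also shuffles the channel and batch dimensions around point_sample):

    import torch

    # sampled: (bs, num_points, C) features gathered at the clicked points;
    # pad:     (bs, num_points) True where the point is padding.
    sampled = torch.randn(2, 4, 8)
    pad = torch.tensor([[False, False, True, True],
                        [True,  True,  True, True]])   # second sample has no valid points

    # Mean over the valid points of each sample; an all-padded sample produces NaNs ...
    pooled = torch.stack([x[~m].mean(dim=0, keepdim=True) for x, m in zip(sampled, pad)])

    # ... which nan_to_num() maps to zeros so downstream layers still see finite values.
    pooled = pooled.nan_to_num()
    print(pooled.shape)  # torch.Size([2, 1, 8])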
| 238 | 
            +
                            # merge positive and negative sample points for self attention
         | 
| 239 | 
            +
                            # pos_neg_points = [x|y for x,y in zip(extra['spatial_query_pos_mask'], extra['spatial_query_neg_mask'])]
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                            # Get layerwise spatial query
         | 
| 242 | 
            +
                            src_spatial_queries = []
         | 
| 243 | 
            +
                            src_spatial_maskings = []
         | 
| 244 | 
            +
                            for i in range(len(src)):
         | 
| 245 | 
            +
                                hw,_,dc = src[i].shape
         | 
| 246 | 
            +
                                src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
         | 
| 247 | 
            +
                                src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                                non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
         | 
| 250 | 
            +
                                non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
         | 
| 251 | 
            +
                                non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                                pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
         | 
| 254 | 
            +
                                pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                                non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
         | 
| 257 | 
            +
                                non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
         | 
| 258 | 
            +
                                non_zero_query_point[non_zero_query_mask] = 0
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                                spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
         | 
| 261 | 
            +
                                spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
         | 
| 262 | 
            +
                                spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                                src_spatial_queries += [spatial_tokens]
         | 
| 265 | 
            +
                                src_spatial_maskings += [non_zero_query_mask]
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                            if 'refimg' in task:
         | 
| 268 | 
            +
                                output_refimg = {}
         | 
| 269 | 
            +
                                output_refimg['spatial_query_pos'] = spatial_query_pos
         | 
| 270 | 
            +
                                output_refimg['spatial_query_neg'] = spatial_query_neg
         | 
| 271 | 
            +
                                output_refimg['src_spatial_queries'] = src_spatial_queries
         | 
| 272 | 
            +
                                output_refimg['src_spatial_maskings'] = src_spatial_maskings
         | 
| 273 | 
            +
                                return output_refimg
         | 
| 274 | 
            +
                        else:
         | 
| 275 | 
            +
                            spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
         | 
| 276 | 
            +
                            spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
         | 
| 277 | 
            +
                            src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
         | 
| 278 | 
            +
                            src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                        # Get object query for spatial index
         | 
| 281 | 
            +
                        self.attention_data.set('queries_spatial', 'queries')
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                        # set spatial memory
         | 
| 284 | 
            +
                        spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 285 | 
            +
                        spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 286 | 
            +
                        self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                        # if 'queries_spatial' in extra:
         | 
| 289 | 
            +
                        #     self.attention_data.set('queries_spatial', 'queries', var=extra['queries_spatial'])
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                        # if spatial_memory_flag:
         | 
| 292 | 
            +
                        #     prev_mask = (extra['prev_mask'].sigmoid() > 0.5).detach()
         | 
| 293 | 
            +
                        #     non_zero_query_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in prev_mask]
         | 
| 294 | 
            +
                        #     non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
         | 
| 295 | 
            +
                        #     non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
         | 
| 296 | 
            +
                        #     spatial_memory = point_sample(mask_features, non_zero_query_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
         | 
| 297 | 
            +
                        #     spatial_memory = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_memory.transpose(1,2), ~non_zero_query_mask)]).transpose(0,1).nan_to_num()
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    if self.task_switch['grounding'] and grounding_extra_flag:
         | 
| 300 | 
            +
                        # Get grounding tokens
         | 
| 301 | 
            +
                        grounding_tokens = extra['grounding_tokens']
         | 
| 302 | 
            +
                        _grounding_tokens = grounding_tokens.detach().clone()
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                        self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
         | 
| 305 | 
            +
                        self.attention_data.set('queries_grounding', 'queries')
         | 
| 306 | 
            +
                        self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    output, query_embed = self.attention_data.cross_attn_variables()
         | 
| 309 | 
            +
                    # prediction heads on learnable query features
         | 
| 310 | 
            +
                    results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
         | 
| 311 | 
            +
                    results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
         | 
| 312 | 
            +
                    results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
         | 
| 313 | 
            +
                    self.attention_data.set_results(results)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    for i in range(self.num_layers):
         | 
| 316 | 
            +
                        level_index = i % self.num_feature_levels
         | 
| 317 | 
            +
                        # CROSS ATTENTION
         | 
| 318 | 
            +
                        output, avg_attn = self.transformer_cross_attention_layers[i](
         | 
| 319 | 
            +
                            output, src[level_index],
         | 
| 320 | 
            +
                            memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
         | 
| 321 | 
            +
                            memory_key_padding_mask=None,  # here we do not apply masking on padded region
         | 
| 322 | 
            +
                            pos=pos[level_index], query_pos=query_embed
         | 
| 323 | 
            +
                        )
         | 
| 324 | 
            +
                        self.attention_data.update_variables(output, 'cross_attn')
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                        # SELF ATTENTION
         | 
| 327 | 
            +
                        self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
         | 
| 328 | 
            +
                        if self.task_switch['spatial'] and spatial_extra_flag:
         | 
| 329 | 
            +
                            # get spatial tokens
         | 
| 330 | 
            +
                            spatial_tokens = src_spatial_queries[level_index]
         | 
| 331 | 
            +
                            _spatial_tokens = spatial_tokens.detach().clone()
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                            self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
         | 
| 334 | 
            +
                            self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                        output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                        output = self.transformer_self_attention_layers[i](
         | 
| 339 | 
            +
                            output, tgt_mask=self_attn_mask,
         | 
| 340 | 
            +
                            tgt_key_padding_mask=None,
         | 
| 341 | 
            +
                            query_pos=query_embed)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                        # FFN
         | 
| 344 | 
            +
                        output = self.transformer_ffn_layers[i](
         | 
| 345 | 
            +
                            output
         | 
| 346 | 
            +
                        )
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                        self.attention_data.update_variables(output, 'self_attn')
         | 
| 349 | 
            +
                        output, query_embed = self.attention_data.cross_attn_variables()
         | 
| 350 | 
            +
                        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
         | 
| 351 | 
            +
                        results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
         | 
| 352 | 
            +
                        results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
         | 
| 353 | 
            +
                        self.attention_data.set_results(results)
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    return self.attention_data.organize_output()
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
         | 
| 358 | 
            +
                    decoder_output = self.decoder_norm(output)
         | 
| 359 | 
            +
                    decoder_output = decoder_output.transpose(0, 1)
         | 
| 360 | 
            +
                    class_embed = decoder_output @ self.class_embed
         | 
| 361 | 
            +
                    outputs_class = self.lang_encoder.compute_similarity(class_embed)
         | 
| 362 | 
            +
                    mask_embed = self.mask_embed(decoder_output)
         | 
| 363 | 
            +
                    outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
         | 
| 364 | 
            +
                    
         | 
| 365 | 
            +
                    outputs_bbox = [None for i in range(len(outputs_mask))]
         | 
| 366 | 
            +
                    if self.task_switch['bbox']:
         | 
| 367 | 
            +
                        outputs_bbox = self.bbox_embed(decoder_output)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    # NOTE: prediction is of higher-resolution
         | 
| 370 | 
            +
                    # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
         | 
| 371 | 
            +
                    attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    # must use bool type
         | 
| 374 | 
            +
                    # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
         | 
| 375 | 
            +
                    attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
         | 
| 376 | 
            +
                    attn_mask = attn_mask.detach()
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    outputs_caption = class_embed
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    results = {
         | 
| 381 | 
            +
                        "attn_mask": attn_mask,
         | 
| 382 | 
            +
                        "predictions_class": outputs_class,
         | 
| 383 | 
            +
                        "predictions_mask": outputs_mask,
         | 
| 384 | 
            +
                        "predictions_bbox": outputs_bbox,
         | 
| 385 | 
            +
                        "predictions_caption": outputs_caption,
         | 
| 386 | 
            +
                        "predictions_maskemb": mask_embed,
         | 
| 387 | 
            +
                    }
         | 
| 388 | 
            +
                    return results
         | 
| 389 | 
            +
             | 
| 390 | 
            +
            @register_decoder
         | 
| 391 | 
            +
            def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
         | 
| 392 | 
            +
                return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
         | 
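One indexing detail shared by all three decoder variants: the layers cycle round-robin over the three feature scales, and the attention mask produced after layer i is already resized for the scale that layer i+1 will consume (size_list[(i + 1) % self.num_feature_levels]). A toy illustration of the schedule, assuming nine layers:

    num_layers, num_feature_levels = 9, 3
    for i in range(num_layers):
        attended = i % num_feature_levels          # scale whose features layer i cross-attends to
        mask_for = (i + 1) % num_feature_levels    # scale the fresh attn_mask is resized for
        print(f"layer {i}: attends level {attended}, builds mask for level {mask_for}")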
    	
        modeling/interface/seem_v1.py
    ADDED
    
    | @@ -0,0 +1,389 @@ | |
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import logging
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from .build import register_decoder
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from .prototype.attention_data_struct_seemv1 import AttentionDataStruct
from ..utils import rand_sample, prepare_features, configurable
from ..modules import PositionEmbeddingSine
from ..modules.point_features import point_sample


class SEEMDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        contxt_len: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        task_switch: dict,
        enforce_input_project: bool,
        max_spatial_len: int,
        attn_arch: dict,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()
        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.contxt_len = contxt_len
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        self.task_switch = task_switch
        self.query_index = {}

        # output FFNs
        self.lang_encoder = lang_encoder
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        trunc_normal_(self.class_embed, std=.02)

        if task_switch['bbox']:
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        if task_switch['spatial']:
            # spatial query
            self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
            trunc_normal_(self.mask_sptial_embed[0], std=.02)
            trunc_normal_(self.mask_sptial_embed[1], std=.02)
            trunc_normal_(self.mask_sptial_embed[2], std=.02)

            self.max_spatial_len = max_spatial_len
            # spatial memory
            num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
            self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
            self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)

            # learnable positive negative indicator
            self.pn_indicator = nn.Embedding(2, hidden_dim)

        # build AttentionDataStruct
        attn_arch['NUM_LAYERS'] = self.num_layers
        self.attention_data = AttentionDataStruct(attn_arch, task_switch)
        self.sample_size = attn_arch['QUERY_NUMBER']

    @classmethod
    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
        ret = {}

        ret["lang_encoder"] = lang_encoder
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
        ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
        ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
        ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        # Transformer parameters:
        ret["nheads"] = dec_cfg['NHEADS']
        ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert dec_cfg['DEC_LAYERS'] >= 1
        ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
        ret["pre_norm"] = dec_cfg['PRE_NORM']
        ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["task_switch"] = extra['task_switch']
        ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']

        # attn data struct
        ret["attn_arch"] = cfg['ATTENTION_ARCH']

        return ret

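For reference, a minimal sketch of how the `DEC_LAYERS - 1` convention in `from_config` plays out. The config keys mirror the ones read above; the numbers are illustrative only and are not the values shipped with this Space's `configs/biomedparse_inference.yaml`:

# Illustrative sketch, not part of the original file; hypothetical values.
cfg = {
    'MODEL': {
        'DIM_PROJ': 512,
        'TEXT': {'CONTEXT_LENGTH': 77},
        'ENCODER': {'MASK_DIM': 512},
        'DECODER': {
            'HIDDEN_DIM': 512, 'NUM_OBJECT_QUERIES': 101, 'NHEADS': 8,
            'DIM_FEEDFORWARD': 2048,
            'DEC_LAYERS': 10,          # config value
            'PRE_NORM': False, 'ENFORCE_INPUT_PROJ': False,
            'MAX_SPATIAL_LEN': [512, 512, 512],
        },
    },
    'ATTENTION_ARCH': {'SPATIAL_MEMORIES': 32, 'QUERY_NUMBER': 3},
}
# from_config subtracts 1, so the module stacks DEC_LAYERS - 1 = 9 decoder
# blocks; together with the extra prediction made on the learnable queries
# this keeps the number of (auxiliary) prediction sets equal to DEC_LAYERS.
dec_layers = cfg['MODEL']['DECODER']['DEC_LAYERS'] - 1
assert dec_layers == 9

Because `__init__` is wrapped with `@configurable`, constructing `SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)` routes `cfg` through `from_config` to produce these keyword arguments.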
    def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels; del mask
        spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
        grounding_extra_flag = 'grounding_tokens' in extra.keys()
        spatial_memory_flag = 'prev_mask' in extra.keys()
        flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
        self.attention_data.reset(flags, task, extra)

        src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
        _,bs,_ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        self.attention_data.set('queries_object', 'queries', output, query_embed)

        if self.task_switch['spatial'] and spatial_extra_flag:
            if 'refimg_tokens' not in extra:
                # get divisor
                c,h,w = extra['spatial_query_pos_mask'][0].shape
                divisor = torch.tensor([1,h,w], device=output.device)[None,]

                # Get mean pos spatial query
                non_zero_pos_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
                non_zero_pos_index = [m[:,0:1].long() for m in non_zero_pos_point]
                non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
                non_zero_pos_index = nn.utils.rnn.pad_sequence(non_zero_pos_index, padding_value=-1).permute(1,0,2)[:,:,0]
                non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
                spatial_query_pos = point_sample(mask_features, non_zero_pos_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                num_mask_per_batch = [len(m) for m in extra['spatial_query_pos_mask']]
                spatial_query_pos = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_pos.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask, non_zero_pos_index, num_mask_per_batch)], padding_value=-1).nan_to_num()

                # Get mean neg spatial query
                non_zero_neg_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
                non_zero_neg_index = [m[:,0:1].long() for m in non_zero_neg_point]
                non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
                non_zero_neg_index = nn.utils.rnn.pad_sequence(non_zero_neg_index, padding_value=-1).permute(1,0,2)[:,:,0]
                non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
                spatial_query_neg = point_sample(mask_features, non_zero_neg_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                num_mask_per_batch = [len(m) for m in extra['spatial_query_neg_mask']]
                spatial_query_neg = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_neg.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask, non_zero_neg_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
                # Get layerwise spatial query
                src_spatial_queries = []
                src_spatial_maskings = []
                src_spatial_indices = []
                for i in range(len(src)):
                    hw,_,dc = src[i].shape
                    src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
                    src_mask_features = src_mask_features @ self.mask_sptial_embed[i]

                    non_zero_query_point_pos = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
                    non_zero_query_point_neg = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
                    non_zero_query_point = [torch.cat([x[:,1:],y[:,1:]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                    non_zero_query_index = [torch.cat([x[:,0:1],y[:,0:1]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]

                    pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                    pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)

                    non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
                    non_zero_query_index = nn.utils.rnn.pad_sequence(non_zero_query_index, padding_value=-1).permute(1,0,2)
                    non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
                    non_zero_query_point[non_zero_query_mask] = 0

                    spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
                    spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
                    spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]

                    src_spatial_queries += [spatial_tokens]
                    src_spatial_maskings += [non_zero_query_mask]
                    src_spatial_indices += [non_zero_query_index]

                if 'refimg' in task:
                    output_refimg = {}
                    output_refimg['spatial_query_pos'] = spatial_query_pos
                    output_refimg['spatial_query_neg'] = spatial_query_neg
                    output_refimg['src_spatial_queries'] = src_spatial_queries
                    output_refimg['src_spatial_maskings'] = src_spatial_maskings
                    return output_refimg
            else:
                spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
                spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
                src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
                src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']

            # Get object query for spatial index
            self.attention_data.set_extra({"spatial_query_number": len(spatial_query_pos), "sample_size": self.sample_size})
            self.attention_data.set('queries_spatial', 'queries', sample_size=self.sample_size*len(spatial_query_pos))

            # set spatial memory
            spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
            spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
            self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)

        if self.task_switch['grounding'] and grounding_extra_flag:
            # Get grounding tokens
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()

            self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
            self.attention_data.set('queries_grounding', 'queries')
            self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])

        output, query_embed = self.attention_data.cross_attn_variables()
        # prediction heads on learnable query features
        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
        results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
        results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
        self.attention_data.set_results(results)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # CROSS ATTENTION
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )
            self.attention_data.update_variables(output, 'cross_attn')

            # SELF ATTENTION
            self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
            if self.task_switch['spatial'] and spatial_extra_flag:
                # get spatial tokens
                spatial_tokens = src_spatial_queries[level_index]
                _spatial_tokens = spatial_tokens.detach().clone()

                self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
                self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
                self.attention_data.set_extra({"spatial_indices": src_spatial_indices[level_index]})

            output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=self_attn_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            self.attention_data.update_variables(output, 'self_attn')
            output, query_embed = self.attention_data.cross_attn_variables()
            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
            results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
            results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
            self.attention_data.set_results(results)

        return self.attention_data.organize_output()

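The spatial-prompt branch of forward above reduces to: sample point coordinates from the user's positive/negative click masks, normalize them, and read mask features at those locations. A self-contained sketch of that idea, written with plain grid_sample (the repo's point_sample helper behaves as a thin wrapper over it; the tensor shapes and click coordinates below are illustrative only):

# Illustrative sketch, not part of the original file.
import torch
import torch.nn.functional as F

def sample_mask_features(mask_features, points_xy):
    """Read C-dim features at normalized (x, y) points in [0, 1].

    mask_features: (B, C, H, W); points_xy: (B, P, 2).
    Roughly what the point_sample(...) calls above do for spatial queries.
    """
    grid = 2.0 * points_xy - 1.0                      # grid_sample wants [-1, 1]
    sampled = F.grid_sample(
        mask_features, grid.unsqueeze(2),             # (B, P, 1, 2)
        mode='bilinear', align_corners=True,
    )                                                 # (B, C, P, 1)
    return sampled.squeeze(3)                         # (B, C, P)

feats = torch.randn(1, 512, 128, 128)                 # toy mask features
clicks = torch.tensor([[[0.25, 0.40], [0.70, 0.55]]]) # two hypothetical clicks
tokens = sample_mask_features(feats, clicks)          # (1, 512, 2)

The per-level loop then adds the learnable positive/negative indicator embeddings to the sampled tokens, so the decoder can tell inclusion clicks from exclusion clicks.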
    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        class_embed = decoder_output @ self.class_embed
        outputs_class = self.lang_encoder.compute_similarity(class_embed)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        outputs_bbox = [None for i in range(len(outputs_mask))]
        if self.task_switch['bbox']:
            outputs_bbox = self.bbox_embed(decoder_output)

        # NOTE: prediction is of higher-resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)

        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        outputs_caption = class_embed

        results = {
            "attn_mask": attn_mask,
            "predictions_class": outputs_class,
            "predictions_mask": outputs_mask,
            "predictions_bbox": outputs_bbox,
            "predictions_caption": outputs_caption,
            "predictions_maskemb": mask_embed,
        }
        return results


@register_decoder
def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
    return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
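A small standalone sketch of the attention-mask semantics used in forward_prediction_heads: the predicted mask is resized to the current feature-level resolution, and positions whose sigmoid falls below 0.5 are marked True, i.e. blocked for the next round of masked cross-attention. Sizes below are toy values, not the real config:

# Illustrative sketch, not part of the original file; toy sizes.
import torch
import torch.nn.functional as F

B, Q, H, W = 1, 5, 64, 64
num_heads = 8
outputs_mask = torch.randn(B, Q, H, W)          # per-query mask logits
attn_mask_target_size = (16, 16)                # current feature-level size

attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size,
                          mode="bilinear", align_corners=False)
# True = this key position may NOT be attended (below-threshold background)
attn_mask = (attn_mask.sigmoid()
             .flatten(2)                        # (B, Q, H*W)
             .unsqueeze(1)
             .repeat(1, num_heads, 1, 1)        # (B, h, Q, H*W)
             .flatten(0, 1) < 0.5).bool()       # (B*h, Q, H*W)
assert attn_mask.shape == (B * num_heads, Q, 16 * 16)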
    	
        modeling/interface/xdecoder.py
    ADDED
    
@@ -0,0 +1,497 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import logging
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from .build import register_decoder
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from ..utils import configurable
from ..modules import PositionEmbeddingSine


class XDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        contxt_len: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        task_switch: dict,
        captioning_step: int,
        enforce_input_project: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()
        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.contxt_len = contxt_len
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        self.task_switch = task_switch

        # output FFNs
        self.lang_encoder = lang_encoder
        if self.task_switch['mask']:
            self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        trunc_normal_(self.class_embed, std=.02)

        if task_switch['bbox']:
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # Caption Project and query
        if task_switch['captioning']:
            self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
            trunc_normal_(self.caping_embed, std=.02)
            self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
            self.captioning_step = captioning_step

        # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query
        self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
        self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query.
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token.
        self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query.
        self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query.
        self.register_buffer("self_attn_mask", self_attn_mask)

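To make the masking rules just above concrete, here is a tiny standalone rendering of the same block layout with hypothetical sizes (3 object queries plus 1 class query, and 4 caption tokens): caption tokens get a causal triangle, and object/class queries are kept from mixing with them.

# Illustrative sketch, not part of the original file; toy sizes.
import torch

num_queries, contxt_len = 4, 4       # 3 object + 1 class query, 4 caption tokens
n = num_queries + contxt_len
self_attn_mask = torch.zeros((1, n, n)).bool()   # False = attention allowed

# object/class queries never look at caption tokens
self_attn_mask[:, :num_queries, num_queries:] = True
# caption tokens are causal: each token sees only itself and earlier tokens
self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
    torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool()
# object queries and the (last) class query are kept separate
self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True
self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True

print(self_attn_mask[0].int())       # 1 = blocked, 0 = allowed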
    @classmethod
    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
        ret = {}

        ret["lang_encoder"] = lang_encoder
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
        ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
        ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
        ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        # Transformer parameters:
        ret["nheads"] = dec_cfg['NHEADS']
        ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert dec_cfg['DEC_LAYERS'] >= 1
        ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
        ret["pre_norm"] = dec_cfg['PRE_NORM']
        ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
        ret["mask_dim"] = enc_cfg['MASK_DIM']

        ret["task_switch"] = extra['task_switch']
        ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)

        return ret

    def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        if task == 'captioning_infer':
            return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_class = []
        predictions_mask = []
        predictions_bbox = []
        predictions_caption = []
        predictions_captioning = []

        self_tgt_mask = None
        if self.training and task == 'vlp' and self.task_switch['captioning']:
            # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
            caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
            _caping_lang_embed = caping_lang_embed.detach().clone()
            output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
            caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
            query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
            self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
        elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
            self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
         | 
| 238 | 
            +
                        # initialize with negative attention at the beginning.
         | 
| 239 | 
            +
                        pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
         | 
| 240 | 
            +
                        pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
         | 
| 241 | 
            +
                        pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
         | 
| 242 | 
            +
                        self_tgt_mask = pad_tgt_mask
         | 
| 243 | 
            +
                        output = torch.cat((output, output[:-1]), dim=0)
         | 
| 244 | 
            +
                        query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad the language embedding to match the padded output
         | 
| 245 | 
            +
                    else:
         | 
| 246 | 
            +
                        self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    # prediction heads on learnable query features
         | 
| 249 | 
            +
                    results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
         | 
| 250 | 
            +
                    attn_mask = results["attn_mask"]
         | 
| 251 | 
            +
                    predictions_class.append(results["outputs_class"])
         | 
| 252 | 
            +
                    predictions_mask.append(results["outputs_mask"])
         | 
| 253 | 
            +
                    predictions_bbox.append(results["outputs_bbox"])
         | 
| 254 | 
            +
                    predictions_caption.append(results["outputs_caption"])
         | 
| 255 | 
            +
                    predictions_captioning.append(results["outputs_captionting"])
         | 
| 256 | 
            +
                    
         | 
| 257 | 
            +
                    for i in range(self.num_layers):
         | 
| 258 | 
            +
                        level_index = i % self.num_feature_levels
         | 
| 259 | 
            +
                        attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                        if self.training and task == 'vlp' and self.task_switch['captioning']:
         | 
| 262 | 
            +
                            attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
         | 
| 263 | 
            +
                        # attention: cross-attention first
         | 
| 264 | 
            +
                        output, avg_attn = self.transformer_cross_attention_layers[i](
         | 
| 265 | 
            +
                            output, src[level_index],
         | 
| 266 | 
            +
                            memory_mask=attn_mask,
         | 
| 267 | 
            +
                            memory_key_padding_mask=None,  # here we do not apply masking on padded region
         | 
| 268 | 
            +
                            pos=pos[level_index], query_pos=query_embed
         | 
| 269 | 
            +
                        )
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                        if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
         | 
| 272 | 
            +
                            output = torch.cat((output, _grounding_tokens), dim=0)
         | 
| 273 | 
            +
                            query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                        output = self.transformer_self_attention_layers[i](
         | 
| 276 | 
            +
                            output, tgt_mask=self_tgt_mask,
         | 
| 277 | 
            +
                            tgt_key_padding_mask=None,
         | 
| 278 | 
            +
                            query_pos=query_embed
         | 
| 279 | 
            +
                        )
         | 
| 280 | 
            +
                        
         | 
| 281 | 
            +
                        # FFN
         | 
| 282 | 
            +
                        output = self.transformer_ffn_layers[i](
         | 
| 283 | 
            +
                            output
         | 
| 284 | 
            +
                        )
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                        if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']:
         | 
| 287 | 
            +
                            _grounding_tokens = output[-len(_grounding_tokens):]
         | 
| 288 | 
            +
                            output = output[:-len(_grounding_tokens)]
         | 
| 289 | 
            +
                            query_embed = query_embed[:-len(_grounding_tokens)]
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
         | 
| 292 | 
            +
                        attn_mask = results["attn_mask"]
         | 
| 293 | 
            +
                        predictions_class.append(results["outputs_class"])
         | 
| 294 | 
            +
                        predictions_mask.append(results["outputs_mask"])
         | 
| 295 | 
            +
                        predictions_bbox.append(results["outputs_bbox"])
         | 
| 296 | 
            +
                        predictions_caption.append(results["outputs_caption"])
         | 
| 297 | 
            +
                        predictions_captioning.append(results["outputs_captionting"])
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    assert len(predictions_class) == self.num_layers + 1
         | 
| 300 | 
            +
                    if task == 'vlp':
         | 
| 301 | 
            +
                        out = {'pred_captionings': predictions_captioning[-1], 
         | 
| 302 | 
            +
                               'pred_captions': predictions_caption[-1], 
         | 
| 303 | 
            +
                               'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
         | 
| 304 | 
            +
                        return out
         | 
| 305 | 
            +
                    else:
         | 
| 306 | 
            +
                        out = {
         | 
| 307 | 
            +
                            'pred_logits': predictions_class[-1],
         | 
| 308 | 
            +
                            'pred_masks': predictions_mask[-1],
         | 
| 309 | 
            +
                            'pred_boxes': predictions_bbox[-1],
         | 
| 310 | 
            +
                            'pred_captions': predictions_caption[-1],
         | 
| 311 | 
            +
                            'aux_outputs': self._set_aux_loss(
         | 
| 312 | 
            +
                                predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
         | 
| 313 | 
            +
                            )
         | 
| 314 | 
            +
                        }
         | 
| 315 | 
            +
                        return out
         | 
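For tasks other than 'vlp', the forward pass above returns the last decoder layer's predictions plus 'aux_outputs' with one entry per earlier layer for deep supervision. A minimal sketch of how a caller might consume that dictionary, assuming pred_masks is [B, Q, H', W'] at mask-feature resolution as produced by forward_prediction_heads; the helper name and sizes are illustrative, not part of the repo:

    import torch
    import torch.nn.functional as F

    def masks_from_decoder_output(out: dict, image_size) -> torch.Tensor:
        # Upsample per-query mask logits to the input resolution and binarize.
        mask_logits = out['pred_masks']                      # [B, Q, H', W']
        mask_logits = F.interpolate(mask_logits, size=image_size,
                                    mode='bilinear', align_corners=False)
        return mask_logits.sigmoid() > 0.5                   # [B, Q, H, W] boolean masks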
| 316 | 
            +
             | 
| 317 | 
            +
                def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
         | 
| 318 | 
            +
                    # x is a list of multi-scale feature
         | 
| 319 | 
            +
                    assert len(x) == self.num_feature_levels
         | 
| 320 | 
            +
                    src = []
         | 
| 321 | 
            +
                    pos = []
         | 
| 322 | 
            +
                    size_list = []
         | 
| 323 | 
            +
                    
         | 
| 324 | 
            +
                    # disable mask, it does not affect performance
         | 
| 325 | 
            +
                    del mask
         | 
| 326 | 
            +
                    for i in range(self.num_feature_levels):
         | 
| 327 | 
            +
                        size_list.append(x[i].shape[-2:])
         | 
| 328 | 
            +
                        pos.append(self.pe_layer(x[i], None).flatten(2))
         | 
| 329 | 
            +
                        src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                        # flatten NxCxHxW to HWxNxC
         | 
| 332 | 
            +
                        pos[-1] = pos[-1].permute(2, 0, 1)
         | 
| 333 | 
            +
                        src[-1] = src[-1].permute(2, 0, 1)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    _, bs, _ = src[0].shape
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    # QxNxC
         | 
| 338 | 
            +
                    query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 339 | 
            +
                    query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)        
         | 
| 340 | 
            +
                    caping_lang_token = extra['start_token'].repeat(bs, 1)
         | 
| 341 | 
            +
                    pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    # prepare token embedding for evaluation
         | 
| 344 | 
            +
                    token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
         | 
| 345 | 
            +
                    # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 346 | 
            +
                    
         | 
| 347 | 
            +
                    for cap_idx in range(0, self.captioning_step):
         | 
| 348 | 
            +
                        caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
         | 
| 349 | 
            +
                        output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
         | 
| 350 | 
            +
                        caping_lang_embed += pos_embed_caping
         | 
| 351 | 
            +
                        query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
         | 
| 352 | 
            +
                        # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                        # prediction heads on learnable query features
         | 
| 355 | 
            +
                        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
         | 
| 356 | 
            +
                        attn_mask = results["attn_mask"]
         | 
| 357 | 
            +
                    
         | 
| 358 | 
            +
                        for i in range(self.num_layers):
         | 
| 359 | 
            +
                            level_index = i % self.num_feature_levels
         | 
| 360 | 
            +
                            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
         | 
| 361 | 
            +
                            attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
         | 
| 362 | 
            +
                            self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                            if extra['captioning_mask'] is not None:
         | 
| 365 | 
            +
                                bs,nq,wh = attn_mask.shape
         | 
| 366 | 
            +
                                assert bs==self.num_heads, "Only support single image referring captioning."
         | 
| 367 | 
            +
                                cap_mask = extra['captioning_mask']
         | 
| 368 | 
            +
                                attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
         | 
| 369 | 
            +
                                cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
         | 
| 370 | 
            +
                                attn_mask[:,self.num_queries:, cap_mask] = True
         | 
| 371 | 
            +
                                attn_mask = attn_mask.reshape(bs,nq,wh)
         | 
| 372 | 
            +
                            
         | 
| 373 | 
            +
                            # attention: cross-attention first
         | 
| 374 | 
            +
                            output, avg_attn = self.transformer_cross_attention_layers[i](
         | 
| 375 | 
            +
                                output, src[level_index],
         | 
| 376 | 
            +
                                memory_mask=attn_mask,
         | 
| 377 | 
            +
                                memory_key_padding_mask=None,  # here we do not apply masking on padded region
         | 
| 378 | 
            +
                                pos=pos[level_index], query_pos=query_embed
         | 
| 379 | 
            +
                            )
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                            output = self.transformer_self_attention_layers[i](
         | 
| 382 | 
            +
                                output, tgt_mask=self_tgt_mask,
         | 
| 383 | 
            +
                                tgt_key_padding_mask=None,
         | 
| 384 | 
            +
                                query_pos=query_embed
         | 
| 385 | 
            +
                            )
         | 
| 386 | 
            +
                            
         | 
| 387 | 
            +
                            # FFN
         | 
| 388 | 
            +
                            output = self.transformer_ffn_layers[i](
         | 
| 389 | 
            +
                                output
         | 
| 390 | 
            +
                            )
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
         | 
| 393 | 
            +
                            attn_mask = results["attn_mask"]
         | 
| 394 | 
            +
                        
         | 
| 395 | 
            +
                        pred_captions_gen = results['outputs_captionting']
         | 
| 396 | 
            +
                        # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 397 | 
            +
                        pred_captions_gen = pred_captions_gen @ token_embs.t()
         | 
| 398 | 
            +
                        caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
         | 
| 401 | 
            +
                    texts_new = []
         | 
| 402 | 
            +
                    
         | 
| 403 | 
            +
                    for x in texts:
         | 
| 404 | 
            +
                        x = x.split('<|endoftext|>')[0]
         | 
| 405 | 
            +
                        x = x.replace('<|endoftext|>','')
         | 
| 406 | 
            +
                        x = x.replace('<|startoftext|>','')
         | 
| 407 | 
            +
                        x = x.strip()
         | 
| 408 | 
            +
                        texts_new.append(x)
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    out = {'pred_captionings': caping_lang_token,
         | 
| 411 | 
            +
                           'pred_texts': texts_new}
         | 
| 412 | 
            +
                    return out
         | 
| 413 | 
            +
             | 
| 414 | 
            +
             | 
| 415 | 
            +
                def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
         | 
| 416 | 
            +
                    decoder_output = self.decoder_norm(output)
         | 
| 417 | 
            +
                    decoder_output = decoder_output.transpose(0, 1)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    # extract image captioning token from decoder output.
         | 
| 420 | 
            +
                    if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
         | 
| 421 | 
            +
                        outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
         | 
| 422 | 
            +
                    else:
         | 
| 423 | 
            +
                        outputs_captionting = None
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    # recompute class token output.
         | 
| 426 | 
            +
                    norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 427 | 
            +
                    obj_token = norm_decoder_output[:,:self.num_queries-1]
         | 
| 428 | 
            +
                    cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                    sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
         | 
| 431 | 
            +
                    cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
         | 
| 434 | 
            +
                        decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
         | 
| 435 | 
            +
                    else:
         | 
| 436 | 
            +
                        decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                    # compute class, mask and bbox.
         | 
| 439 | 
            +
                    class_embed = decoder_output @ self.class_embed
         | 
| 440 | 
            +
                    # HACK do not compute similarity if mask is not on
         | 
| 441 | 
            +
                    outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training)))
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                    if self.task_switch['mask']:
         | 
| 444 | 
            +
                        mask_embed = self.mask_embed(decoder_output)
         | 
| 445 | 
            +
                        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                        # NOTE: prediction is of higher-resolution
         | 
| 448 | 
            +
                        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
         | 
| 449 | 
            +
                        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bicubic", align_corners=False, antialias=True)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                        # must use bool type
         | 
| 452 | 
            +
                        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
         | 
| 453 | 
            +
                        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
         | 
| 454 | 
            +
                        attn_mask = attn_mask.detach()
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                        # NOTE: fill False for cls token (JY)
         | 
| 457 | 
            +
                        attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
         | 
| 458 | 
            +
                    else:
         | 
| 459 | 
            +
                        outputs_mask = None
         | 
| 460 | 
            +
                        attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    outputs_bbox = [None for i in range(len(decoder_output))]
         | 
| 463 | 
            +
                    if self.task_switch['bbox']:
         | 
| 464 | 
            +
                        outputs_bbox = self.bbox_embed(decoder_output)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    outputs_caption = None
         | 
| 467 | 
            +
                    if self.task_switch['caption']:
         | 
| 468 | 
            +
                        outputs_caption = class_embed
         | 
| 469 | 
            +
                        
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                    results = {
         | 
| 472 | 
            +
                        "outputs_class": outputs_class,
         | 
| 473 | 
            +
                        "outputs_mask": outputs_mask,
         | 
| 474 | 
            +
                        "outputs_bbox": outputs_bbox,
         | 
| 475 | 
            +
                        "attn_mask": attn_mask,
         | 
| 476 | 
            +
                        "outputs_caption": outputs_caption,
         | 
| 477 | 
            +
                        "outputs_captionting": outputs_captionting,
         | 
| 478 | 
            +
                    }
         | 
| 479 | 
            +
                    return results
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                @torch.jit.unused
         | 
| 482 | 
            +
                def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
         | 
| 483 | 
            +
                    # this is a workaround to make torchscript happy, as torchscript
         | 
| 484 | 
            +
                    # doesn't support dictionary with non-homogeneous values, such
         | 
| 485 | 
            +
                    # as a dict having both a Tensor and a list.
         | 
| 486 | 
            +
                    if self.mask_classification:
         | 
| 487 | 
            +
                        return [
         | 
| 488 | 
            +
                            {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
         | 
| 489 | 
            +
                            for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
         | 
| 490 | 
            +
                        ]
         | 
| 491 | 
            +
                    else:
         | 
| 492 | 
            +
                        return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
         | 
| 493 | 
            +
             | 
| 494 | 
            +
             | 
| 495 | 
            +
            @register_decoder
         | 
| 496 | 
            +
            def get_xdecoder_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
         | 
| 497 | 
            +
                return XDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
         | 
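For reference, a minimal sketch of the nested config that from_config above reads; the numeric values are illustrative placeholders, not the project's defaults:

    cfg = {
        'MODEL': {
            'DIM_PROJ': 512,
            'TEXT': {'CONTEXT_LENGTH': 77},
            'ENCODER': {'MASK_DIM': 512},
            'DECODER': {
                'HIDDEN_DIM': 512,
                'NUM_OBJECT_QUERIES': 101,
                'NHEADS': 8,
                'DIM_FEEDFORWARD': 2048,
                'DEC_LAYERS': 10,        # from_config stores DEC_LAYERS - 1
                'PRE_NORM': False,
                'ENFORCE_INPUT_PROJ': False,
                'CAPTIONING': {'STEP': 50},
            },
        },
    }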
    	
        modeling/language/LangEncoder/__init__.py
    ADDED
    
    | @@ -0,0 +1,35 @@ | |
| 1 | 
            +
            from transformers import CLIPTokenizer, CLIPTokenizerFast
         | 
| 2 | 
            +
            from transformers import AutoTokenizer
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from .transformer import *
         | 
| 5 | 
            +
            from .build import *
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
         | 
| 9 | 
            +
                model_name = config_encoder['NAME']
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                if not is_lang_encoder(model_name):
         | 
| 12 | 
            +
        raise ValueError(f'Unknown model: {model_name}')
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            def build_tokenizer(config_encoder):
         | 
| 17 | 
            +
                tokenizer = None
         | 
| 18 | 
            +
                os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         | 
| 19 | 
            +
                if config_encoder['TOKENIZER'] == 'clip':
         | 
| 20 | 
            +
                    pretrained_tokenizer = config_encoder.get(
         | 
| 21 | 
            +
                        'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
         | 
| 22 | 
            +
                    )
         | 
| 23 | 
            +
                    tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
         | 
| 24 | 
            +
                    tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
         | 
| 25 | 
            +
                elif config_encoder['TOKENIZER'] == 'clip-fast':
         | 
| 26 | 
            +
                    pretrained_tokenizer = config_encoder.get(
         | 
| 27 | 
            +
                        'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
         | 
| 28 | 
            +
                    )
         | 
| 29 | 
            +
                    tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
         | 
| 30 | 
            +
                elif config_encoder['TOKENIZER'] == 'biomed-clip':
         | 
| 31 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
         | 
| 32 | 
            +
                else:
         | 
| 33 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                return tokenizer
         | 
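A short usage sketch for build_tokenizer above, assuming the repository root is on the Python path and its dependencies are installed; the prompt text and sequence length are illustrative:

    from modeling.language.LangEncoder import build_tokenizer

    config_encoder = {'TOKENIZER': 'clip'}   # or 'clip-fast' / 'biomed-clip'
    tokenizer = build_tokenizer(config_encoder)
    tokens = tokenizer(['a pathology image of breast tissue'],
                       padding='max_length', truncation=True,
                       max_length=77, return_tensors='pt')
    print(tokens['input_ids'].shape)         # e.g. torch.Size([1, 77])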
    	
        modeling/language/LangEncoder/build.py
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
| 1 | 
            +
            _lang_encoders = {}
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def register_lang_encoder(fn):
         | 
| 5 | 
            +
                module_name_split = fn.__module__.split('.')
         | 
| 6 | 
            +
                model_name = module_name_split[-1]
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                _lang_encoders[model_name] = fn
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                return fn
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def lang_encoders(model_name):
         | 
| 13 | 
            +
                return _lang_encoders[model_name]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            def is_lang_encoder(model_name):
         | 
| 16 | 
            +
                return model_name in _lang_encoders
         | 
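The registry above keys on the defining module's name, so a hypothetical file modeling/language/LangEncoder/my_encoder.py would register its builder under 'my_encoder'; a sketch in which the encoder is a stand-in, not part of the repo:

    import torch.nn as nn
    from .build import register_lang_encoder

    @register_lang_encoder            # looked up later as lang_encoders('my_encoder')
    def my_encoder(config_encoder, tokenizer, verbose, **kwargs):
        return nn.Identity()          # any module with the expected encoder interface

    # Selected through config, e.g. config_encoder = {'NAME': 'my_encoder', ...},
    # which build_lang_encoder resolves via is_lang_encoder / lang_encoders.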
    	
        modeling/language/LangEncoder/transformer.py
    ADDED
    
    | @@ -0,0 +1,222 @@ | |
| 1 | 
            +
            from collections import OrderedDict
         | 
| 2 | 
            +
            from typing import Tuple, Union
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
            from torch import nn
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from timm.models.layers import DropPath, trunc_normal_
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from .build import register_lang_encoder
         | 
| 14 | 
            +
            from utilities.distributed import is_main_process
         | 
| 15 | 
            +
            from utilities.model import register_norm_module
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            @register_norm_module
         | 
| 21 | 
            +
            class LayerNorm(nn.Module):
         | 
| 22 | 
            +
                def __init__(self, hidden_size, eps=1e-12):
         | 
| 23 | 
            +
                    """Construct a layernorm module in the TF style (epsilon inside the square root).
         | 
| 24 | 
            +
                    """
         | 
| 25 | 
            +
                    super(LayerNorm, self).__init__()
         | 
| 26 | 
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         | 
| 27 | 
            +
                    self.bias = nn.Parameter(torch.zeros(hidden_size))
         | 
| 28 | 
            +
                    self.variance_epsilon = eps
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, x):
         | 
| 31 | 
            +
                    pdtype = x.dtype
         | 
| 32 | 
            +
                    x = x.float()
         | 
| 33 | 
            +
                    u = x.mean(-1, keepdim=True)
         | 
| 34 | 
            +
                    s = (x - u).pow(2).mean(-1, keepdim=True)
         | 
| 35 | 
            +
                    x = (x - u) / torch.sqrt(s + self.variance_epsilon)
         | 
| 36 | 
            +
                    return self.weight * x.to(pdtype) + self.bias
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class QuickGELU(nn.Module):
         | 
| 40 | 
            +
                def forward(self, x: torch.Tensor):
         | 
| 41 | 
            +
                    return x * torch.sigmoid(1.702 * x)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class ResidualAttentionBlock(nn.Module):
         | 
| 45 | 
            +
                def __init__(self,
         | 
| 46 | 
            +
                             d_model: int,
         | 
| 47 | 
            +
                             n_head: int,
         | 
| 48 | 
            +
                             attn_mask: torch.Tensor = None,
         | 
| 49 | 
            +
                             drop_path: float = 0.0):
         | 
| 50 | 
            +
                    super().__init__()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    self.attn = nn.MultiheadAttention(d_model, n_head)
         | 
| 53 | 
            +
                    self.ln_1 = LayerNorm(d_model)
         | 
| 54 | 
            +
                    self.mlp = nn.Sequential(OrderedDict([
         | 
| 55 | 
            +
                        ("c_fc", nn.Linear(d_model, d_model * 4)),
         | 
| 56 | 
            +
                        ("gelu", QuickGELU()),
         | 
| 57 | 
            +
                        ("c_proj", nn.Linear(d_model * 4, d_model))
         | 
| 58 | 
            +
                    ]))
         | 
| 59 | 
            +
                    self.ln_2 = LayerNorm(d_model)
         | 
| 60 | 
            +
                    self.attn_mask = attn_mask
         | 
| 61 | 
            +
                    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
         | 
| 64 | 
            +
                    self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
         | 
| 65 | 
            +
                        if self.attn_mask is not None else None
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
                    return self.attn(
         | 
| 69 | 
            +
                        x, x, x,
         | 
| 70 | 
            +
                        key_padding_mask=key_padding_mask,
         | 
| 71 | 
            +
                        need_weights=False,
         | 
| 72 | 
            +
                        attn_mask=self.attn_mask
         | 
| 73 | 
            +
                    )[0]
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
         | 
| 76 | 
            +
                    x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
         | 
| 77 | 
            +
                    x = x + self.drop_path(self.mlp(self.ln_2(x)))
         | 
| 78 | 
            +
                    return x
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            class Transformer(nn.Module):
         | 
| 82 | 
            +
                def __init__(self,
         | 
| 83 | 
            +
                             context_length: int,
         | 
| 84 | 
            +
                             vocab_size: int,
         | 
| 85 | 
            +
                             width: int,
         | 
| 86 | 
            +
                             layers: int,
         | 
| 87 | 
            +
                             heads: int,
         | 
| 88 | 
            +
                             drop_path: float = 0.0,
         | 
| 89 | 
            +
                         autogressive: bool = True):
         | 
| 90 | 
            +
                    super().__init__()
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    self.token_embedding = nn.Embedding(vocab_size, width)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    self.context_length = context_length
         | 
| 95 | 
            +
                    self.positional_embedding = nn.Parameter(
         | 
| 96 | 
            +
                        torch.empty(self.context_length, width)
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.width = width
         | 
| 100 | 
            +
                    self.layers = layers
         | 
| 101 | 
            +
                    self.autogressive = autogressive
         | 
| 102 | 
            +
                    attn_mask = self.build_attention_mask() if autogressive else None
         | 
| 103 | 
            +
                    dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
         | 
| 104 | 
            +
                    self.resblocks = nn.ModuleList(
         | 
| 105 | 
            +
                        [
         | 
| 106 | 
            +
                            ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
         | 
| 107 | 
            +
                            for i in range(layers)
         | 
| 108 | 
            +
                        ]
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.ln_final = LayerNorm(width)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    trunc_normal_(self.positional_embedding, std=.02)
         | 
| 114 | 
            +
                    # nn.init.normal_(self.token_embedding, std=.02)
         | 
| 115 | 
            +
                    trunc_normal_(self.token_embedding.weight, std=.02)
         | 
| 116 | 
            +
                    self.apply(self._init_weights)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                @property
         | 
| 119 | 
            +
                def dim_out(self):
         | 
| 120 | 
            +
                    return self.width
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def build_attention_mask(self):
         | 
| 123 | 
            +
                    # lazily create causal attention mask, with full attention between the vision tokens
         | 
| 124 | 
            +
                    # pytorch uses additive attention mask; fill with -inf
         | 
| 125 | 
            +
                    mask = torch.empty(self.context_length, self.context_length)
         | 
| 126 | 
            +
                    mask.fill_(float("-inf"))
         | 
| 127 | 
            +
                    mask.triu_(1)  # zero out the lower diagonal
         | 
| 128 | 
            +
                    return mask
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def _init_weights(self, m):
         | 
| 131 | 
            +
                    if isinstance(m, (nn.Linear, nn.Conv2d)):
         | 
| 132 | 
            +
                        if is_main_process():
         | 
| 133 | 
            +
                            logger.info('=> init weight of Linear/Conv2d from trunc norm')
         | 
| 134 | 
            +
                        trunc_normal_(m.weight, std=0.02)
         | 
| 135 | 
            +
                        if m.bias is not None:
         | 
| 136 | 
            +
                            if is_main_process():
         | 
| 137 | 
            +
                                logger.info('=> init bias of Linear/Conv2d to zeros')
         | 
| 138 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 139 | 
            +
                    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
         | 
| 140 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
         | 
| 143 | 
            +
                    if os.path.isfile(pretrained):
         | 
| 144 | 
            +
                        pretrained_dict = torch.load(pretrained, map_location='cpu')
         | 
| 145 | 
            +
                        logging.info(f'=> loading pretrained model {pretrained}')
         | 
| 146 | 
            +
                        model_dict = self.state_dict()
         | 
| 147 | 
            +
                        stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
         | 
| 148 | 
            +
                        pretrained_dict = {
         | 
| 149 | 
            +
                            stripped_key(k): v for k, v in pretrained_dict.items()
         | 
| 150 | 
            +
                            if stripped_key(k) in model_dict.keys()
         | 
| 151 | 
            +
                        }
         | 
| 152 | 
            +
                        need_init_state_dict = {}
         | 
| 153 | 
            +
                        for k, v in pretrained_dict.items():
         | 
| 154 | 
            +
                            need_init = (
         | 
| 155 | 
            +
                                k.split('.')[0] in pretrained_layers
         | 
| 156 | 
            +
                                or pretrained_layers[0] == '*'
         | 
| 157 | 
            +
                            )
         | 
| 158 | 
            +
                            if need_init:
         | 
| 159 | 
            +
                                if verbose:
         | 
| 160 | 
            +
                                    logger.info(f'=> init {k} from {pretrained}')
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                                if 'positional_embedding' in k and v.size() != model_dict[k].size():
         | 
| 163 | 
            +
                                    positional_embedding_pretrained = v
         | 
| 164 | 
            +
                                    positional_embedding_current = model_dict[k]
         | 
| 165 | 
            +
                                    L1, nH1 = positional_embedding_pretrained.size()
         | 
| 166 | 
            +
                                    L2, nH2 = positional_embedding_current.size()
         | 
| 167 | 
            +
                                    if nH1 != nH2:
         | 
| 168 | 
            +
                                        logger.info(f"Error in loading {k}, passing")
         | 
| 169 | 
            +
                                    else:
         | 
| 170 | 
            +
                                        if L1 != L2:
         | 
| 171 | 
            +
                                            logger.info(
         | 
| 172 | 
            +
                                                '=> load_pretrained: resized variant: {} to {}'
         | 
| 173 | 
            +
                                                    .format((L1, nH1), (L2, nH2))
         | 
| 174 | 
            +
                                            )
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                                            posemb = positional_embedding_pretrained.float()
         | 
| 177 | 
            +
                                            posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
         | 
| 178 | 
            +
                                            posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
         | 
| 179 | 
            +
                                            posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
         | 
| 180 | 
            +
                                            v = posemb_grid
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                                need_init_state_dict[k] = v
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                        self.load_state_dict(need_init_state_dict, strict=False)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
                @torch.jit.ignore
         | 
| 188 | 
            +
                def no_weight_decay(self):
         | 
| 189 | 
            +
                    return {
         | 
| 190 | 
            +
                        'positional_embedding',
         | 
| 191 | 
            +
                        'token_embedding',
         | 
| 192 | 
            +
                    }
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def forward(self, input_ids, attention_mask=None):
         | 
| 195 | 
            +
                    key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
         | 
| 196 | 
            +
                    # key_padding_mask = (input_ids == 0) if not self.autogressive else None
         | 
| 197 | 
            +
                    x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
         | 
| 198 | 
            +
                    x = x + self.positional_embedding
         | 
| 199 | 
            +
                    x = x.permute(1, 0, 2)  # NLD -> LND
         | 
| 200 | 
            +
                    for block in self.resblocks:
         | 
| 201 | 
            +
                        x = block(x, key_padding_mask)
         | 
| 202 | 
            +
                    x = x.permute(1, 0, 2)  # LND -> NLD
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    x = self.ln_final(x)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    return {'last_hidden_state': x}
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            @register_lang_encoder
         | 
| 210 | 
            +
            def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
         | 
| 211 | 
            +
                transformer = Transformer(
         | 
| 212 | 
            +
                    context_length=config_encoder['CONTEXT_LENGTH'],
         | 
| 213 | 
            +
                    vocab_size=tokenizer.vocab_size,
         | 
| 214 | 
            +
                    width=config_encoder['WIDTH'],
         | 
| 215 | 
            +
                    layers=config_encoder['LAYERS'],
         | 
| 216 | 
            +
                    heads=config_encoder['HEADS'],
         | 
| 217 | 
            +
                    autogressive=config_encoder.get('AUTOGRESSIVE', True)
         | 
| 218 | 
            +
                )
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                if config_encoder.get('LOAD_PRETRAINED', False):
         | 
| 221 | 
            +
                    transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
         | 
| 222 | 
            +
                return transformer
         | 
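A sketch of instantiating the text Transformer above directly and running a forward pass inside the repo environment; the hyper-parameters are illustrative, not the project's defaults:

    import torch
    from modeling.language.LangEncoder.transformer import Transformer

    text_encoder = Transformer(context_length=77, vocab_size=49408,
                               width=512, layers=6, heads=8, autogressive=True)
    input_ids = torch.randint(0, 49408, (2, 77))   # [batch, context_length]
    out = text_encoder(input_ids)
    print(out['last_hidden_state'].shape)          # torch.Size([2, 77, 512])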
    	
        modeling/language/__init__.py
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
| 1 | 
            +
            from .vlpencoder import *
         | 
| 2 | 
            +
            from .build import *
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def build_language_encoder(config, **kwargs):
         | 
| 5 | 
            +
                model_name = config['MODEL']['TEXT']['ARCH']
         | 
| 6 | 
            +
             | 
| 7 | 
            +
                if not is_model(model_name):
         | 
| 8 | 
            +
        raise ValueError(f'Unknown model: {model_name}')
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                return model_entrypoints(model_name)(config, **kwargs)
         | 
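A sketch of the dispatch above: MODEL.TEXT.ARCH must match a builder registered in build.py under its module name (for example, a builder defined in vlpencoder.py would be looked up as 'vlpencoder'); the ARCH value and any extra keys the chosen builder reads are assumptions here:

    config = {'MODEL': {'TEXT': {'ARCH': 'vlpencoder'}}}
    try:
        lang_encoder = build_language_encoder(config)
    except ValueError as e:
        print(e)    # "Unknown model: ..." if the ARCH is not registered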
    	
        modeling/language/build.py
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
| 1 | 
            +
            _model_entrypoints = {}
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def register_model(fn):
         | 
| 5 | 
            +
                module_name_split = fn.__module__.split('.')
         | 
| 6 | 
            +
                model_name = module_name_split[-1]
         | 
| 7 | 
            +
                _model_entrypoints[model_name] = fn
         | 
| 8 | 
            +
                return fn
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            def model_entrypoints(model_name):
         | 
| 11 | 
            +
                return _model_entrypoints[model_name]
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            def is_model(model_name):
         | 
| 14 | 
            +
                return model_name in _model_entrypoints
         | 
    	
        modeling/language/loss.py
    ADDED
    
    | @@ -0,0 +1,232 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
         | 
| 3 | 
            +
            # Copyright (c) 2022 Microsoft
         | 
| 4 | 
            +
            # Licensed under The MIT License [see LICENSE for details]
         | 
| 5 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 6 | 
            +
            # --------------------------------------------------------
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import pickle
         | 
| 9 | 
            +
            from distutils import log
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import torch.nn.functional as F
         | 
| 13 | 
            +
            import torch.distributed as dist
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from einops import rearrange, repeat
         | 
| 16 | 
            +
            from timm.loss import SoftTargetCrossEntropy
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            soft_cross_entropy = SoftTargetCrossEntropy()
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            def is_dist_initialized():
         | 
| 21 | 
            +
                return torch.distributed.is_initialized()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            def get_world_size():
         | 
| 24 | 
            +
                if is_dist_initialized():
         | 
| 25 | 
            +
                    return torch.distributed.get_world_size()
         | 
| 26 | 
            +
                return 1
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_rank():
         | 
| 29 | 
            +
                if is_dist_initialized():
         | 
| 30 | 
            +
                    return dist.get_rank()
         | 
| 31 | 
            +
                return 0
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            def all_gather_grad(x):
         | 
| 34 | 
            +
                if get_world_size() > 1:
         | 
| 35 | 
            +
                    all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
         | 
| 36 | 
            +
                    torch.distributed.all_gather(all_x, x)
         | 
| 37 | 
            +
                    all_x[torch.distributed.get_rank()] = x
         | 
| 38 | 
            +
                    x = torch.cat(all_x, dim=0)
         | 
| 39 | 
            +
                return x
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
                Args:
         | 
| 44 | 
            +
                    image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
         | 
| 45 | 
            +
                    text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                Returns:
                    torch.Tensor: scalar loss, the mean of the image-to-text and text-to-image soft cross-entropy terms.
         | 
| 48 | 
            +
                """
         | 
| 49 | 
            +
                # [B, L1, C], L1 = 1
         | 
| 50 | 
            +
                # image_feat = F.normalize(image_feat, dim=-1)
         | 
| 51 | 
            +
                # [B, L2, C]
         | 
| 52 | 
            +
                # text_feat = F.normalize(text_feat, dim=-1)
         | 
| 53 | 
            +
                # HACK: normalize outside
         | 
| 54 | 
            +
                
         | 
| 55 | 
            +
                # [B, L1, L2]
         | 
| 56 | 
            +
                dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')    
         | 
| 57 | 
            +
                # [B, L2, L1]
         | 
| 58 | 
            +
                dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                batch = image_feat.shape[0]
         | 
| 61 | 
            +
                img_len = image_feat.shape[1]
         | 
| 62 | 
            +
                text_len = text_feat.shape[1]
         | 
| 63 | 
            +
                # [B, L1, L2]
         | 
| 64 | 
            +
                pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
         | 
| 65 | 
            +
                # [B, L2, L1]
         | 
| 66 | 
            +
                pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                image_x = rearrange(image_feat, 'b l c -> (b l) c')
         | 
| 69 | 
            +
                text_x = rearrange(text_feat, 'b l c -> (b l) c')
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                logits_per_img = image_x @ all_gather_grad(text_x).t()
         | 
| 72 | 
            +
                logits_per_text = text_x @ all_gather_grad(image_x).t()
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                # get label globally
         | 
| 75 | 
            +
                # [B, L1, B, L2, W]
         | 
| 76 | 
            +
                labels_per_img = F.one_hot(
         | 
| 77 | 
            +
                    torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
         | 
| 78 | 
            +
                    num_classes=get_world_size()).to(image_x.dtype)
         | 
| 79 | 
            +
                labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
         | 
| 80 | 
            +
                    torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
         | 
| 81 | 
            +
                # [BxL1, WxBxL2]
         | 
| 82 | 
            +
                labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
         | 
| 83 | 
            +
                # [B, L2, B, L1, W]
         | 
| 84 | 
            +
                labels_per_text = F.one_hot(
         | 
| 85 | 
            +
                    torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
         | 
| 86 | 
            +
                    num_classes=get_world_size()).to(text_x.dtype)
         | 
| 87 | 
            +
                labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
         | 
| 88 | 
            +
                    torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
         | 
| 89 | 
            +
                # [BxL2, WxBxL1]
         | 
| 90 | 
            +
                labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                logit_scale = temperature.exp().clamp(max=100)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
         | 
| 95 | 
            +
                loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                loss = 0.5 * (loss_img + loss_text)
         | 
| 98 | 
            +
                return loss
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            def vl_contrastive_loss(image_feat, text_feat, temperature=1):
         | 
| 101 | 
            +
                # if image_id or text_id is None, it should be None across all GPUs
         | 
| 102 | 
            +
                # image_feat = F.normalize(image_feat, dim=1)
         | 
| 103 | 
            +
                # text_feat = F.normalize(text_feat, dim=1)
         | 
| 104 | 
            +
                # handle normalization outside
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # gather image and text features from every GPU before computing the contrastive logits
         | 
| 107 | 
            +
                image_feat = all_gather_grad(image_feat)
         | 
| 108 | 
            +
                text_feat = all_gather_grad(text_feat)
         | 
| 109 | 
            +
                
         | 
| 110 | 
            +
                logits = torch.matmul(image_feat, text_feat.t())
         | 
| 111 | 
            +
                logit_scale = temperature.exp().clamp(max=100)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                gt = torch.arange(logits.shape[0], device=logits.device)
         | 
| 114 | 
            +
                loss1 = F.cross_entropy(logit_scale * logits, gt)
         | 
| 115 | 
            +
                loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
         | 
| 116 | 
            +
                return (loss1 + loss2) / 2 # scale it up by the number of GPUs
         | 
| 117 | 
            +
             | 
| 118 | 
            +
             | 
| 119 | 
            +
            def all_gather_pickle(data, device):
         | 
| 120 | 
            +
                """
         | 
| 121 | 
            +
                Run all_gather on arbitrary picklable data (not necessarily tensors)
         | 
| 122 | 
            +
                Args:
         | 
| 123 | 
            +
                    data: any picklable object
         | 
| 124 | 
            +
                Returns:
         | 
| 125 | 
            +
                    list[data]: list of data gathered from each rank
         | 
| 126 | 
            +
                """
         | 
| 127 | 
            +
                world_size = get_world_size()
         | 
| 128 | 
            +
                if world_size == 1:
         | 
| 129 | 
            +
                    return [data]
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                # serialized to a Tensor
         | 
| 132 | 
            +
                buffer = pickle.dumps(data)
         | 
| 133 | 
            +
                storage = torch.ByteStorage.from_buffer(buffer)
         | 
| 134 | 
            +
                tensor = torch.ByteTensor(storage).to(device)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                # obtain Tensor size of each rank
         | 
| 137 | 
            +
                local_size = torch.LongTensor([tensor.numel()]).cuda()
         | 
| 138 | 
            +
                size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
         | 
| 139 | 
            +
                dist.all_gather(size_list, local_size)
         | 
| 140 | 
            +
                size_list = [int(size.item()) for size in size_list]
         | 
| 141 | 
            +
                max_size = max(size_list)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                # receiving Tensor from all ranks
         | 
| 144 | 
            +
                # we pad the tensor because torch all_gather does not support
         | 
| 145 | 
            +
                # gathering tensors of different shapes
         | 
| 146 | 
            +
                tensor_list = []
         | 
| 147 | 
            +
                for _ in size_list:
         | 
| 148 | 
            +
                    tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())
         | 
| 149 | 
            +
                if local_size != max_size:
         | 
| 150 | 
            +
                    padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
         | 
| 151 | 
            +
                    tensor = torch.cat((tensor, padding), dim=0)
         | 
| 152 | 
            +
                dist.all_gather(tensor_list, tensor)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                data_list = []
         | 
| 155 | 
            +
                for size, tensor in zip(size_list, tensor_list):
         | 
| 156 | 
            +
                    buffer = tensor.cpu().numpy().tobytes()[:size]
         | 
| 157 | 
            +
                    data_list.append(pickle.loads(buffer))
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                return data_list
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            def all_gather_arbitary_tensor(tensor):
         | 
| 162 | 
            +
                if get_world_size() > 1:
         | 
| 163 | 
            +
                    device = tensor.device
         | 
| 164 | 
            +
                    tensor_batch = all_gather_pickle(tensor.cpu(), device)
         | 
| 165 | 
            +
                    tensor_batch = [x.to(device) for x in tensor_batch]
         | 
| 166 | 
            +
                    tensor_batch[torch.distributed.get_rank()] = tensor
         | 
| 167 | 
            +
                    tensor_batch = torch.cat(tensor_batch, dim=0)
         | 
| 168 | 
            +
                else:
         | 
| 169 | 
            +
                    tensor_batch = tensor
         | 
| 170 | 
            +
                return tensor_batch
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            def ql_contrastive_loss(image_feat, text_feat, temperature=1):
         | 
| 173 | 
            +
                # gather image and text features from every GPU before computing the contrastive logits
         | 
| 174 | 
            +
                image_feat = all_gather_arbitary_tensor(image_feat)
         | 
| 175 | 
            +
                text_feat = all_gather_arbitary_tensor(text_feat)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                logits = torch.matmul(image_feat, text_feat.t())
         | 
| 178 | 
            +
                logit_scale = temperature.exp().clamp(max=100)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                gt = torch.arange(logits.shape[0], device=logits.device)
         | 
| 181 | 
            +
                loss1 = F.cross_entropy(logit_scale * logits, gt)
         | 
| 182 | 
            +
                loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
         | 
| 183 | 
            +
                return (loss1 + loss2) / 2 # scale it up by the number of GPUs
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            def vl_similarity(image_feat, text_feat, temperature=1):
         | 
| 186 | 
            +
                # Only supports a single GPU for now.
         | 
| 187 | 
            +
                logits = torch.matmul(image_feat, text_feat.t())
         | 
| 188 | 
            +
                logits = temperature.exp().clamp(max=100) * logits
         | 
| 189 | 
            +
                return logits
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
         | 
| 192 | 
            +
                # gather image and text features from every GPU before computing the contrastive logits
         | 
| 193 | 
            +
                image_feat = all_gather_arbitary_tensor(image_feat)
         | 
| 194 | 
            +
                text_feat = all_gather_arbitary_tensor(text_feat)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
         | 
| 197 | 
            +
                text_hash_all = torch.cat(text_hash_batch)
         | 
| 198 | 
            +
                
         | 
| 199 | 
            +
                text_hash_all_unique = torch.unique(text_hash_all).tolist()
         | 
| 200 | 
            +
                gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
         | 
| 201 | 
            +
                text_hash_all = text_hash_all.tolist()
         | 
| 202 | 
            +
                text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                for idx, txt in enumerate(text_hash_all):
         | 
| 205 | 
            +
                    gt[idx][text_hash_all_unique.index(txt)] = 1
         | 
| 206 | 
            +
                
         | 
| 207 | 
            +
                logits = torch.matmul(image_feat, text_feat_unique.t())
         | 
| 208 | 
            +
                logits = logits*temperature.exp().clamp(max=100)
         | 
| 209 | 
            +
                
         | 
| 210 | 
            +
                loss_img = soft_cross_entropy(logits, gt)
         | 
| 211 | 
            +
                loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                loss = 0.7 * loss_img + 0.3 * loss_text
         | 
| 214 | 
            +
                return loss
         | 
| 215 | 
            +
             | 
| 216 | 
            +
            def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
         | 
| 217 | 
            +
                # gather image and text features from every GPU before computing the contrastive logits
         | 
| 218 | 
            +
                image_feat = all_gather_grad(image_feat_inp.contiguous())
         | 
| 219 | 
            +
                text_feat = all_gather_grad(text_feat_inp.contiguous())
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 222 | 
            +
                text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                temperature = lang_enc.logit_scale
         | 
| 225 | 
            +
                logits = torch.matmul(image_feat, text_feat.t())
         | 
| 226 | 
            +
                logit_scale = temperature.exp().clamp(max=100)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                gt = torch.arange(logits.shape[0], device=logits.device)
         | 
| 229 | 
            +
                loss1 = F.cross_entropy(logit_scale * logits, gt)
         | 
| 230 | 
            +
                loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                return (loss1 + loss2) / 2 # scale it up by the number of GPUs
         | 
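All of the contrastive losses above share one recipe: gather image and text features (across GPUs when distributed training is initialized), build an image-text similarity matrix, scale it by a learned temperature, and apply cross-entropy with the matching pairs as targets. A single-process toy call of vl_contrastive_loss (illustrative; it assumes the repo's dependencies such as timm and einops are installed, the feature values are random, and the temperature is passed as a tensor because the function calls .exp() on it):

import torch
import torch.nn.functional as F
from modeling.language.loss import vl_contrastive_loss, vl_similarity

B, C = 4, 512
image_feat = F.normalize(torch.randn(B, C), dim=-1)   # the functions expect pre-normalized features
text_feat = F.normalize(torch.randn(B, C), dim=-1)
temperature = torch.tensor(2.659)                      # roughly log(1/0.07), a common CLIP-style value

# Without torch.distributed initialized, all_gather_grad() is a no-op, so this is plain symmetric InfoNCE.
loss = vl_contrastive_loss(image_feat, text_feat, temperature)
logits = vl_similarity(image_feat, text_feat, temperature)   # [B, B]; entry (i, i) is the matching pair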
    	
        modeling/language/misc.py
    ADDED
    
    | @@ -0,0 +1,66 @@ | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import nltk
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from utilities.constants import IMAGENET_DEFAULT_TEMPLATES
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            nltk.download('punkt', quiet=True)
         | 
| 10 | 
            +
            nltk.download('averaged_perceptron_tagger', quiet=True)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def get_tag(tokenized, tags):
         | 
| 13 | 
            +
                if not isinstance(tags, (list, tuple)):
         | 
| 14 | 
            +
                    tags = [tags]
         | 
| 15 | 
            +
                ret = []
         | 
| 16 | 
            +
                for (word, pos) in nltk.pos_tag(tokenized):
         | 
| 17 | 
            +
                    for tag in tags:
         | 
| 18 | 
            +
                        if pos == tag:
         | 
| 19 | 
            +
                            ret.append(word)
         | 
| 20 | 
            +
                return ret
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            def get_noun_phrase(tokenized):
         | 
| 23 | 
            +
                # Taken from Su Nam Kim Paper...
         | 
| 24 | 
            +
                grammar = r"""
         | 
| 25 | 
            +
                    NBAR:
         | 
| 26 | 
            +
                        {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    NP:
         | 
| 29 | 
            +
                        {<NBAR>}
         | 
| 30 | 
            +
                        {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                chunker = nltk.RegexpParser(grammar)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                chunked = chunker.parse(nltk.pos_tag(tokenized))
         | 
| 35 | 
            +
                continuous_chunk = []
         | 
| 36 | 
            +
                current_chunk = []
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                for subtree in chunked:
         | 
| 39 | 
            +
                    if isinstance(subtree, nltk.Tree):
         | 
| 40 | 
            +
                        current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
         | 
| 41 | 
            +
                    elif current_chunk:
         | 
| 42 | 
            +
                        named_entity = ' '.join(current_chunk)
         | 
| 43 | 
            +
                        if named_entity not in continuous_chunk:
         | 
| 44 | 
            +
                            continuous_chunk.append(named_entity)
         | 
| 45 | 
            +
                            current_chunk = []
         | 
| 46 | 
            +
                    else:
         | 
| 47 | 
            +
                        continue
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                return continuous_chunk
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
         | 
| 52 | 
            +
                tokenized = nltk.word_tokenize(text)
         | 
| 53 | 
            +
                
         | 
| 54 | 
            +
                if random.random() >= phrase_prob:
         | 
| 55 | 
            +
                    nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    nouns = get_noun_phrase(tokenized)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
                prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
                if append_text:
         | 
| 63 | 
            +
                    prompt_texts += [text]
         | 
| 64 | 
            +
                    nouns += [text]
         | 
| 65 | 
            +
                
         | 
| 66 | 
            +
                return prompt_texts, nouns
         | 
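text_noun_with_prompt_all above turns a caption into noun-level prompts: with probability 1 - phrase_prob it keeps single nouns tagged NN/NNS/NNP, otherwise it chunks noun phrases with the NBAR/NP grammar, and each extracted noun is dropped into a randomly chosen ImageNet-style template. A toy call (illustrative; it needs the repo's utilities.constants and the NLTK data downloaded at import time, and the exact nouns depend on the tagger):

from modeling.language.misc import text_noun_with_prompt_all

caption = "a CT scan of the liver with a small lesion"
prompts, nouns = text_noun_with_prompt_all(caption, phrase_prob=0.0, append_text=True)
# phrase_prob=0.0 forces the single-noun path, so nouns is roughly ['scan', 'liver', 'lesion'] plus the caption itself,
# and prompts holds one templated sentence per noun with the raw caption appended at the end.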
    	
        modeling/language/vlpencoder.py
    ADDED
    
    | @@ -0,0 +1,206 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
         | 
| 3 | 
            +
            # Copyright (c) 2022 Microsoft
         | 
| 4 | 
            +
            # Licensed under The MIT License [see LICENSE for details]
         | 
| 5 | 
            +
            # Written by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 6 | 
            +
            # --------------------------------------------------------
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            from torch import nn
         | 
| 10 | 
            +
            from torch.nn import functional as F
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from timm.models.layers import trunc_normal_
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from .build import register_model
         | 
| 15 | 
            +
            from ..utils import configurable
         | 
| 16 | 
            +
            from .LangEncoder import build_tokenizer, build_lang_encoder
         | 
| 17 | 
            +
            from utilities.prompt_engineering import prompt_engineering, get_prompt_templates
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from transformers import AutoTokenizer, AutoModel
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            class LanguageEncoder(nn.Module):
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                @configurable
         | 
| 24 | 
            +
                def __init__(
         | 
| 25 | 
            +
                    self,
         | 
| 26 | 
            +
                    tokenizer,
         | 
| 27 | 
            +
                    tokenizer_type,
         | 
| 28 | 
            +
                    lang_encoder,
         | 
| 29 | 
            +
                    lang_projection,
         | 
| 30 | 
            +
                    max_token_num,
         | 
| 31 | 
            +
                    queue_operator,
         | 
| 32 | 
            +
                ):
         | 
| 33 | 
            +
                    super().__init__()
         | 
| 34 | 
            +
                    # seg
         | 
| 35 | 
            +
                    self.tokenizer = tokenizer
         | 
| 36 | 
            +
                    self.tokenizer_type = tokenizer_type
         | 
| 37 | 
            +
                    self.lang_encoder = lang_encoder
         | 
| 38 | 
            +
                    self.lang_proj = lang_projection
         | 
| 39 | 
            +
                    self.max_token_num = max_token_num
         | 
| 40 | 
            +
                    self.logit_scale = nn.Parameter(torch.ones([]))
         | 
| 41 | 
            +
                    
         | 
| 42 | 
            +
                    # captioning & retrieval
         | 
| 43 | 
            +
                    for key, value in queue_operator.items():
         | 
| 44 | 
            +
                        self.register_buffer(key, value)
         | 
| 45 | 
            +
                        
         | 
| 46 | 
            +
                    self.biomed_encoder = AutoModel.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                @classmethod
         | 
| 49 | 
            +
                def from_config(cls, cfg):
         | 
| 50 | 
            +
                    # build up text encoder for seg
         | 
| 51 | 
            +
                    tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
         | 
| 52 | 
            +
                    tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
         | 
| 53 | 
            +
                    lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
         | 
| 54 | 
            +
                    max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
         | 
| 55 | 
            +
                    
         | 
| 56 | 
            +
                    dim_lang = cfg['MODEL']['TEXT']['WIDTH']
         | 
| 57 | 
            +
                    dim_projection = cfg['MODEL']['DIM_PROJ']
         | 
| 58 | 
            +
                    lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
         | 
| 59 | 
            +
                    trunc_normal_(lang_projection, std=.02)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    # a feature queue was tested here but did not improve results, so no queue buffers are registered
         | 
| 62 | 
            +
                    queue_operator = {}
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    return {
         | 
| 65 | 
            +
                        "tokenizer": tokenizer,
         | 
| 66 | 
            +
                        "tokenizer_type": tokenizer_type,
         | 
| 67 | 
            +
                        "lang_encoder": lang_encoder,
         | 
| 68 | 
            +
                        "lang_projection": lang_projection,
         | 
| 69 | 
            +
                        "max_token_num": max_token_num,
         | 
| 70 | 
            +
                        "queue_operator": queue_operator,
         | 
| 71 | 
            +
                    }
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True, store_buffer=None):
         | 
| 74 | 
            +
                    if not is_eval:
         | 
| 75 | 
            +
                        if prompt:
         | 
| 76 | 
            +
                            # randomly sample one template
         | 
| 77 | 
            +
                            arbitary_concepts = [
         | 
| 78 | 
            +
                                prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
         | 
| 79 | 
            +
                                for label in range(len(class_names))
         | 
| 80 | 
            +
                            ]
         | 
| 81 | 
            +
                            if add_bgd:
         | 
| 82 | 
            +
                                arbitary_concepts.append("A background in coco.")
         | 
| 83 | 
            +
                        else:
         | 
| 84 | 
            +
                            arbitary_concepts = class_names
         | 
| 85 | 
            +
                        
         | 
| 86 | 
            +
                        input_ids = []
         | 
| 87 | 
            +
                        attention_masks = []
         | 
| 88 | 
            +
                        for txt in arbitary_concepts:
         | 
| 89 | 
            +
                            tokens = self.tokenizer(
         | 
| 90 | 
            +
                                txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
         | 
| 91 | 
            +
                            )
         | 
| 92 | 
            +
                            tokens['input_ids'].squeeze_()
         | 
| 93 | 
            +
                            tokens['attention_mask'].squeeze_()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                            input_ids.append(tokens['input_ids'])
         | 
| 96 | 
            +
                            attention_masks.append(tokens['attention_mask'])
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                        arbitary_tokens = torch.stack(input_ids)
         | 
| 99 | 
            +
                        arbitary_attention_masks = torch.stack(attention_masks)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                        text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
         | 
| 102 | 
            +
                        setattr(self, '{}_text_embeddings'.format(name), text_emb)
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        with torch.no_grad():
         | 
| 105 | 
            +
                            def extract_mean_emb(txts):
         | 
| 106 | 
            +
                                tokens = self.tokenizer(
         | 
| 107 | 
            +
                                    txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
         | 
| 108 | 
            +
                                )
         | 
| 109 | 
            +
                                clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
         | 
| 110 | 
            +
                                clss_embedding = clss_embedding.mean(dim=0)
         | 
| 111 | 
            +
                                clss_embedding /= clss_embedding.norm()
         | 
| 112 | 
            +
                                return clss_embedding
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                            templates = get_prompt_templates()
         | 
| 115 | 
            +
                            clss_embeddings = []
         | 
| 116 | 
            +
                            if prompt:
         | 
| 117 | 
            +
                                for clss in class_names:
         | 
| 118 | 
            +
                                    txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
         | 
| 119 | 
            +
                                    clss_embeddings.append(extract_mean_emb(txts))
         | 
| 120 | 
            +
                            else:
         | 
| 121 | 
            +
                                for clss in class_names:
         | 
| 122 | 
            +
                                    clss_embeddings.append(extract_mean_emb([clss]))
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                            if add_bgd:
         | 
| 125 | 
            +
                                txts = ["A background in coco."]
         | 
| 126 | 
            +
                                clss_embeddings.append(extract_mean_emb(txts))
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                            text_emb = torch.stack(clss_embeddings, dim=0)
         | 
| 129 | 
            +
                            setattr(self, '{}_text_embeddings'.format(name), text_emb)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def reset_text_embeddings(self, name='default'):
         | 
| 132 | 
            +
                    pass
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
         | 
| 135 | 
            +
                    if not token:
         | 
| 136 | 
            +
                        tokens = self.tokenizer(
         | 
| 137 | 
            +
                            txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
         | 
| 138 | 
            +
                        )
         | 
| 139 | 
            +
                        tokens = {key: value.cuda() for key, value in tokens.items()}
         | 
| 140 | 
            +
                    else:
         | 
| 141 | 
            +
                        tokens = txts
         | 
| 142 | 
            +
                    token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
         | 
| 143 | 
            +
                    ret = {"tokens": tokens,
         | 
| 144 | 
            +
                            "token_emb": token_emb,
         | 
| 145 | 
            +
                            "class_emb": class_emb,}
         | 
| 146 | 
            +
                    setattr(self, '{}_token_embeddings'.format(name), ret)
         | 
| 147 | 
            +
                    return ret
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                def forward_language(self, texts, norm=True):
         | 
| 150 | 
            +
                    if self.tokenizer_type == 'biomed-clip':
         | 
| 151 | 
            +
                        with torch.no_grad():  # Disable gradient calculation
         | 
| 152 | 
            +
                            outputs = self.biomed_encoder(*texts)
         | 
| 153 | 
            +
                        # Extract the last hidden state
         | 
| 154 | 
            +
                        x = outputs['last_hidden_state']
         | 
| 155 | 
            +
                        x = x[:, 0]  # Get the [CLS] token's embeddings for all examples
         | 
| 156 | 
            +
                    else:
         | 
| 157 | 
            +
                        x = self.lang_encoder(*texts)
         | 
| 158 | 
            +
                        x = x['last_hidden_state']
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                        if self.tokenizer_type == 'clip':
         | 
| 161 | 
            +
                            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
         | 
| 162 | 
            +
                        else:
         | 
| 163 | 
            +
                            x = x[:, 0]
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    x = x @ self.lang_proj
         | 
| 166 | 
            +
                    if norm:
         | 
| 167 | 
            +
                        x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 168 | 
            +
                    return x
         | 
| 169 | 
            +
                
         | 
| 170 | 
            +
                def forward_language_token(self, texts, norm=False):
         | 
| 171 | 
            +
                    if self.tokenizer_type == 'biomed-clip':
         | 
| 172 | 
            +
                        with torch.no_grad():  # Disable gradient calculation
         | 
| 173 | 
            +
                            outputs = self.biomed_encoder(*texts)
         | 
| 174 | 
            +
                        # Extract the last hidden state
         | 
| 175 | 
            +
                        token_x = outputs['last_hidden_state']
         | 
| 176 | 
            +
                        class_x = token_x[:, 0]  # Get the [CLS] token's embeddings for all examples
         | 
| 177 | 
            +
                    else:
         | 
| 178 | 
            +
                        x = self.lang_encoder(*texts)
         | 
| 179 | 
            +
                        token_x = x['last_hidden_state']
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                        if self.tokenizer_type == 'clip':
         | 
| 182 | 
            +
                            class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
         | 
| 183 | 
            +
                        else:
         | 
| 184 | 
            +
                            class_x = token_x[:, 0]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    class_x = class_x @ self.lang_proj
         | 
| 187 | 
            +
                    token_x = token_x @ self.lang_proj
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    if norm:
         | 
| 190 | 
            +
                        class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 191 | 
            +
                        token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    return token_x, class_x
         | 
| 194 | 
            +
                
         | 
| 195 | 
            +
                def compute_similarity(self, v_emb, name='default', fake=False):
         | 
| 196 | 
            +
                    if fake:
         | 
| 197 | 
            +
                        return None
         | 
| 198 | 
            +
                    v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 199 | 
            +
                    t_emb = getattr(self, '{}_text_embeddings'.format(name))
         | 
| 200 | 
            +
                    output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
         | 
| 201 | 
            +
                    return output
         | 
| 202 | 
            +
             | 
| 203 | 
            +
             | 
| 204 | 
            +
            @register_model
         | 
| 205 | 
            +
            def get_language_model(cfg, **kwargs):
         | 
| 206 | 
            +
                return LanguageEncoder(cfg)
         | 
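Elsewhere in the model, LanguageEncoder is typically used in two steps: cache class-name embeddings once with get_text_embeddings, then score visual embeddings against them with compute_similarity, which applies the learned logit_scale. A hedged sketch of that flow (illustrative only, not runnable as-is: it assumes a full config, a GPU because the encoder calls .cuda() internally, and a feature width equal to cfg['MODEL']['DIM_PROJ']):

# Sketch only: zero-shot class scoring with the encoder above.
import torch
from modeling.language import build_language_encoder

cfg = ...                                                # full config dict, e.g. loaded from configs/biomedparse_inference.yaml
lang_encoder = build_language_encoder(cfg).cuda()

class_names = ['tumor', 'liver', 'background']           # illustrative target names
lang_encoder.get_text_embeddings(class_names, name='demo', is_eval=True, prompt=True)

v_emb = torch.randn(1, 100, 512).cuda()                  # stand-in decoder queries; 512 assumes DIM_PROJ=512
logits = lang_encoder.compute_similarity(v_emb, name='demo')   # [1, num_queries, len(class_names)], scaled by logit_scale.exp()
pred = logits.argmax(dim=-1)                             # best-matching class name per query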
    	
        modeling/modules/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
| 1 | 
            +
            from .point_features import *
         | 
| 2 | 
            +
            from .position_encoding import *
         | 
| 3 | 
            +
            from .postprocessing import *
         | 
| 4 | 
            +
            from .attention import *
         | 
| 5 | 
            +
            from .criterion import *
         | 
| 6 | 
            +
            from .matcher import *
         | 
    	
        modeling/modules/attention.py
    ADDED
    
    | @@ -0,0 +1,487 @@ | |
| 1 | 
            +
            import warnings
         | 
| 2 | 
            +
            from typing import Optional, Tuple
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            from torch import Tensor
         | 
| 7 | 
            +
            from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
         | 
| 8 | 
            +
            from torch.nn.parameter import Parameter
         | 
| 9 | 
            +
            from torch.overrides import has_torch_function, handle_torch_function
         | 
| 10 | 
            +
            from torch.nn.functional import pad, linear, softmax, dropout
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def multi_head_attention_forward(
         | 
| 14 | 
            +
                query: Tensor,
         | 
| 15 | 
            +
                key: Tensor,
         | 
| 16 | 
            +
                value: Tensor,
         | 
| 17 | 
            +
                embed_dim_to_check: int,
         | 
| 18 | 
            +
                num_heads: int,
         | 
| 19 | 
            +
                in_proj_weight: Tensor,
         | 
| 20 | 
            +
                in_proj_bias: Tensor,
         | 
| 21 | 
            +
                bias_k: Optional[Tensor],
         | 
| 22 | 
            +
                bias_v: Optional[Tensor],
         | 
| 23 | 
            +
                add_zero_attn: bool,
         | 
| 24 | 
            +
                dropout_p: float,
         | 
| 25 | 
            +
                out_proj_weight: Tensor,
         | 
| 26 | 
            +
                out_proj_bias: Tensor,
         | 
| 27 | 
            +
                training: bool = True,
         | 
| 28 | 
            +
                key_padding_mask: Optional[Tensor] = None,
         | 
| 29 | 
            +
                need_weights: bool = True,
         | 
| 30 | 
            +
                attn_mask: Optional[Tensor] = None,
         | 
| 31 | 
            +
                use_separate_proj_weight: bool = False,
         | 
| 32 | 
            +
                q_proj_weight: Optional[Tensor] = None,
         | 
| 33 | 
            +
                k_proj_weight: Optional[Tensor] = None,
         | 
| 34 | 
            +
                v_proj_weight: Optional[Tensor] = None,
         | 
| 35 | 
            +
                static_k: Optional[Tensor] = None,
         | 
| 36 | 
            +
                static_v: Optional[Tensor] = None,
         | 
| 37 | 
            +
            ) -> Tuple[Tensor, Optional[Tensor]]:
         | 
| 38 | 
            +
                r"""
         | 
| 39 | 
            +
                Args:
         | 
| 40 | 
            +
                    query, key, value: map a query and a set of key-value pairs to an output.
         | 
| 41 | 
            +
                        See "Attention Is All You Need" for more details.
         | 
| 42 | 
            +
                    embed_dim_to_check: total dimension of the model.
         | 
| 43 | 
            +
                    num_heads: parallel attention heads.
         | 
| 44 | 
            +
                    in_proj_weight, in_proj_bias: input projection weight and bias.
         | 
| 45 | 
            +
                    bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
         | 
| 46 | 
            +
                    add_zero_attn: add a new batch of zeros to the key and
         | 
| 47 | 
            +
                                   value sequences at dim=1.
         | 
| 48 | 
            +
                    dropout_p: probability of an element to be zeroed.
         | 
| 49 | 
            +
                    out_proj_weight, out_proj_bias: the output projection weight and bias.
         | 
| 50 | 
            +
                    training: apply dropout if is ``True``.
         | 
| 51 | 
            +
                    key_padding_mask: if provided, specified padding elements in the key will
         | 
| 52 | 
            +
                        be ignored by the attention. This is a binary mask. When the value is True,
         | 
| 53 | 
            +
                        the corresponding value on the attention layer will be filled with -inf.
         | 
| 54 | 
            +
                    need_weights: output attn_output_weights.
         | 
| 55 | 
            +
                    attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
         | 
| 56 | 
            +
                        the batches while a 3D mask allows specifying a different mask for the entries of each batch.
         | 
| 57 | 
            +
                    use_separate_proj_weight: the function accepts the proj. weights for query, key,
         | 
| 58 | 
            +
                        and value in different forms. If false, in_proj_weight will be used, which is
         | 
| 59 | 
            +
                        a combination of q_proj_weight, k_proj_weight, v_proj_weight.
         | 
| 60 | 
            +
                    q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
         | 
| 61 | 
            +
                    static_k, static_v: static key and value used for attention operators.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
                Shape:
         | 
| 65 | 
            +
                    Inputs:
         | 
| 66 | 
            +
                    - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
         | 
| 67 | 
            +
                      the embedding dimension.
         | 
| 68 | 
            +
                    - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
         | 
| 69 | 
            +
                      the embedding dimension.
         | 
| 70 | 
            +
                    - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
         | 
| 71 | 
            +
                      the embedding dimension.
         | 
| 72 | 
            +
                    - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
         | 
| 73 | 
            +
                      If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
         | 
| 74 | 
            +
                      will be unchanged. If a BoolTensor is provided, the positions with the
         | 
| 75 | 
            +
                      value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
         | 
| 76 | 
            +
                    - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
         | 
| 77 | 
            +
                      3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
         | 
| 78 | 
            +
                      S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
         | 
| 79 | 
            +
                      positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
         | 
| 80 | 
            +
                      while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
         | 
| 81 | 
            +
                      are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
         | 
| 82 | 
            +
                      is provided, it will be added to the attention weight.
         | 
| 83 | 
            +
                    - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
         | 
| 84 | 
            +
                      N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
         | 
| 85 | 
            +
                    - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
         | 
| 86 | 
            +
                      N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    Outputs:
         | 
| 89 | 
            +
                    - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
         | 
| 90 | 
            +
                      E is the embedding dimension.
         | 
| 91 | 
            +
                    - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
         | 
| 92 | 
            +
                      L is the target sequence length, S is the source sequence length.
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
         | 
| 95 | 
            +
                if has_torch_function(tens_ops):
         | 
| 96 | 
            +
                    return handle_torch_function(
         | 
| 97 | 
            +
                        multi_head_attention_forward,
         | 
| 98 | 
            +
                        tens_ops,
         | 
| 99 | 
            +
                        query,
         | 
| 100 | 
            +
                        key,
         | 
| 101 | 
            +
                        value,
         | 
| 102 | 
            +
                        embed_dim_to_check,
         | 
| 103 | 
            +
                        num_heads,
         | 
| 104 | 
            +
                        in_proj_weight,
         | 
| 105 | 
            +
                        in_proj_bias,
         | 
| 106 | 
            +
                        bias_k,
         | 
| 107 | 
            +
                        bias_v,
         | 
| 108 | 
            +
                        add_zero_attn,
         | 
| 109 | 
            +
                        dropout_p,
         | 
| 110 | 
            +
                        out_proj_weight,
         | 
| 111 | 
            +
                        out_proj_bias,
         | 
| 112 | 
            +
                        training=training,
         | 
| 113 | 
            +
                        key_padding_mask=key_padding_mask,
         | 
| 114 | 
            +
                        need_weights=need_weights,
         | 
| 115 | 
            +
                        attn_mask=attn_mask,
         | 
| 116 | 
            +
                        use_separate_proj_weight=use_separate_proj_weight,
         | 
| 117 | 
            +
                        q_proj_weight=q_proj_weight,
         | 
| 118 | 
            +
                        k_proj_weight=k_proj_weight,
         | 
| 119 | 
            +
                        v_proj_weight=v_proj_weight,
         | 
| 120 | 
            +
                        static_k=static_k,
         | 
| 121 | 
            +
                        static_v=static_v,
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
                tgt_len, bsz, embed_dim = query.size()
         | 
| 124 | 
            +
                assert embed_dim == embed_dim_to_check
         | 
| 125 | 
            +
                # allow MHA to have different sizes for the feature dimension
         | 
| 126 | 
            +
                assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                head_dim = embed_dim // num_heads
         | 
| 129 | 
            +
                assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
         | 
| 130 | 
            +
                scaling = float(head_dim) ** -0.5
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                if not use_separate_proj_weight:
         | 
| 133 | 
            +
                    if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
         | 
| 134 | 
            +
                        # self-attention
         | 
| 135 | 
            +
                        q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    elif key is value or torch.equal(key, value):
         | 
| 138 | 
            +
                        # encoder-decoder attention
         | 
| 139 | 
            +
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
         | 
| 140 | 
            +
                        _b = in_proj_bias
         | 
| 141 | 
            +
                        _start = 0
         | 
| 142 | 
            +
                        _end = embed_dim
         | 
| 143 | 
            +
                        _w = in_proj_weight[_start:_end, :]
         | 
| 144 | 
            +
                        if _b is not None:
         | 
| 145 | 
            +
                            _b = _b[_start:_end]
         | 
| 146 | 
            +
                        q = linear(query, _w, _b)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                        if key is None:
         | 
| 149 | 
            +
                            assert value is None
         | 
| 150 | 
            +
                            k = None
         | 
| 151 | 
            +
                            v = None
         | 
| 152 | 
            +
                        else:
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                            # This is inline in_proj function with in_proj_weight and in_proj_bias
         | 
| 155 | 
            +
                            _b = in_proj_bias
         | 
| 156 | 
            +
                            _start = embed_dim
         | 
| 157 | 
            +
                            _end = None
         | 
| 158 | 
            +
                            _w = in_proj_weight[_start:, :]
         | 
| 159 | 
            +
                            if _b is not None:
         | 
| 160 | 
            +
                                _b = _b[_start:]
         | 
| 161 | 
            +
                            k, v = linear(key, _w, _b).chunk(2, dim=-1)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
         | 
| 165 | 
            +
                        _b = in_proj_bias
         | 
| 166 | 
            +
                        _start = 0
         | 
| 167 | 
            +
                        _end = embed_dim
         | 
| 168 | 
            +
                        _w = in_proj_weight[_start:_end, :]
         | 
| 169 | 
            +
                        if _b is not None:
         | 
| 170 | 
            +
                            _b = _b[_start:_end]
         | 
| 171 | 
            +
                        q = linear(query, _w, _b)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
         | 
| 174 | 
            +
                        _b = in_proj_bias
         | 
| 175 | 
            +
                        _start = embed_dim
         | 
| 176 | 
            +
                        _end = embed_dim * 2
         | 
| 177 | 
            +
                        _w = in_proj_weight[_start:_end, :]
         | 
| 178 | 
            +
                        if _b is not None:
         | 
| 179 | 
            +
                            _b = _b[_start:_end]
         | 
| 180 | 
            +
                        k = linear(key, _w, _b)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
         | 
| 183 | 
            +
                        _b = in_proj_bias
         | 
| 184 | 
            +
                        _start = embed_dim * 2
         | 
| 185 | 
            +
                        _end = None
         | 
| 186 | 
            +
                        _w = in_proj_weight[_start:, :]
         | 
| 187 | 
            +
                        if _b is not None:
         | 
| 188 | 
            +
                            _b = _b[_start:]
         | 
| 189 | 
            +
                        v = linear(value, _w, _b)
         | 
| 190 | 
            +
                else:
         | 
| 191 | 
            +
                    q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
         | 
| 192 | 
            +
                    len1, len2 = q_proj_weight_non_opt.size()
         | 
| 193 | 
            +
                    assert len1 == embed_dim and len2 == query.size(-1)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
         | 
| 196 | 
            +
                    len1, len2 = k_proj_weight_non_opt.size()
         | 
| 197 | 
            +
                    assert len1 == embed_dim and len2 == key.size(-1)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
         | 
| 200 | 
            +
                    len1, len2 = v_proj_weight_non_opt.size()
         | 
| 201 | 
            +
                    assert len1 == embed_dim and len2 == value.size(-1)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    if in_proj_bias is not None:
         | 
| 204 | 
            +
                        q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
         | 
| 205 | 
            +
                        k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
         | 
| 206 | 
            +
                        v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        q = linear(query, q_proj_weight_non_opt, in_proj_bias)
         | 
| 209 | 
            +
                        k = linear(key, k_proj_weight_non_opt, in_proj_bias)
         | 
| 210 | 
            +
                        v = linear(value, v_proj_weight_non_opt, in_proj_bias)
         | 
| 211 | 
            +
                q = q * scaling
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                if attn_mask is not None:
         | 
| 214 | 
            +
                    assert (
         | 
| 215 | 
            +
                        attn_mask.dtype == torch.float32
         | 
| 216 | 
            +
                        or attn_mask.dtype == torch.float64
         | 
| 217 | 
            +
                        or attn_mask.dtype == torch.float16
         | 
| 218 | 
            +
                        or attn_mask.dtype == torch.uint8
         | 
| 219 | 
            +
                        or attn_mask.dtype == torch.bool
         | 
| 220 | 
            +
                    ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
         | 
| 221 | 
            +
                    if attn_mask.dtype == torch.uint8:
         | 
| 222 | 
            +
                        warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
         | 
| 223 | 
            +
                        attn_mask = attn_mask.to(torch.bool)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    if attn_mask.dim() == 2:
         | 
| 226 | 
            +
                        attn_mask = attn_mask.unsqueeze(0)
         | 
| 227 | 
            +
                        if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
         | 
| 228 | 
            +
                            raise RuntimeError("The size of the 2D attn_mask is not correct.")
         | 
| 229 | 
            +
                    elif attn_mask.dim() == 3:
         | 
| 230 | 
            +
                        if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
         | 
| 231 | 
            +
                            raise RuntimeError("The size of the 3D attn_mask is not correct.")
         | 
| 232 | 
            +
                    else:
         | 
| 233 | 
            +
                        raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
         | 
| 234 | 
            +
                    # attn_mask's dim is 3 now.
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                # convert ByteTensor key_padding_mask to bool
         | 
| 237 | 
            +
                if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
         | 
| 238 | 
            +
                    warnings.warn(
         | 
| 239 | 
            +
                        "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
         | 
| 240 | 
            +
                    )
         | 
| 241 | 
            +
                    key_padding_mask = key_padding_mask.to(torch.bool)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                if bias_k is not None and bias_v is not None:
         | 
| 244 | 
            +
                    if static_k is None and static_v is None:
         | 
| 245 | 
            +
                        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
         | 
| 246 | 
            +
                        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
         | 
| 247 | 
            +
                        if attn_mask is not None:
         | 
| 248 | 
            +
                            attn_mask = pad(attn_mask, (0, 1))
         | 
| 249 | 
            +
                        if key_padding_mask is not None:
         | 
| 250 | 
            +
                            key_padding_mask = pad(key_padding_mask, (0, 1))
         | 
| 251 | 
            +
                    else:
         | 
| 252 | 
            +
                        assert static_k is None, "bias cannot be added to static key."
         | 
| 253 | 
            +
                        assert static_v is None, "bias cannot be added to static value."
         | 
| 254 | 
            +
                else:
         | 
| 255 | 
            +
                    assert bias_k is None
         | 
| 256 | 
            +
                    assert bias_v is None
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
         | 
| 259 | 
            +
                if k is not None:
         | 
| 260 | 
            +
                    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
         | 
| 261 | 
            +
                if v is not None:
         | 
| 262 | 
            +
                    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                if static_k is not None:
         | 
| 265 | 
            +
                    assert static_k.size(0) == bsz * num_heads
         | 
| 266 | 
            +
                    assert static_k.size(2) == head_dim
         | 
| 267 | 
            +
                    k = static_k
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                if static_v is not None:
         | 
| 270 | 
            +
                    assert static_v.size(0) == bsz * num_heads
         | 
| 271 | 
            +
                    assert static_v.size(2) == head_dim
         | 
| 272 | 
            +
                    v = static_v
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                src_len = k.size(1)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                if key_padding_mask is not None:
         | 
| 277 | 
            +
                    # assert key_padding_mask.size(0) == bsz
         | 
| 278 | 
            +
                    assert key_padding_mask.size(1) == src_len
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                if add_zero_attn:
         | 
| 281 | 
            +
                    src_len += 1
         | 
| 282 | 
            +
                    k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
         | 
| 283 | 
            +
                    v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
         | 
| 284 | 
            +
                    if attn_mask is not None:
         | 
| 285 | 
            +
                        attn_mask = pad(attn_mask, (0, 1))
         | 
| 286 | 
            +
                    if key_padding_mask is not None:
         | 
| 287 | 
            +
                        key_padding_mask = pad(key_padding_mask, (0, 1))
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                attn_output_weights = torch.bmm(q, k.transpose(1, 2))
         | 
| 290 | 
            +
                assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                if attn_mask is not None:
         | 
| 293 | 
            +
                    if attn_mask.dtype == torch.bool:
         | 
| 294 | 
            +
                        attn_output_weights.masked_fill_(attn_mask, float("-inf"))
         | 
| 295 | 
            +
                    else:
         | 
| 296 | 
            +
                        attn_output_weights += attn_mask
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                if key_padding_mask is not None:
         | 
| 299 | 
            +
                    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
         | 
| 300 | 
            +
                    attn_output_weights = attn_output_weights.masked_fill(
         | 
| 301 | 
            +
                        key_padding_mask.unsqueeze(1),
         | 
| 302 | 
            +
                        float("-inf"),
         | 
| 303 | 
            +
                    )
         | 
| 304 | 
            +
                    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                attn_output_weights = softmax(attn_output_weights, dim=-1)
         | 
| 307 | 
            +
                attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                attn_output = torch.bmm(attn_output_weights, v)
         | 
| 310 | 
            +
                assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         | 
| 311 | 
            +
                attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         | 
| 312 | 
            +
                attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                if need_weights:
         | 
| 315 | 
            +
                    # average attention weights over heads
         | 
| 316 | 
            +
                    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
         | 
| 317 | 
            +
                    return attn_output, attn_output_weights.sum(dim=1) / num_heads
         | 
| 318 | 
            +
                else:
         | 
| 319 | 
            +
                    return attn_output, None
         | 
| 320 | 
            +
             | 
| 321 | 
            +
             | 
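For orientation, a minimal smoke test of the functional form defined above; the sizes and random tensors are illustrative assumptions and are not part of this commit:

    import torch
    # query is (L, N, E); key and value are (S, N, E), matching the Shape section above.
    L, S, N, E, num_heads = 5, 7, 2, 16, 4
    query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)
    in_proj_weight, in_proj_bias = torch.randn(3 * E, E), torch.zeros(3 * E)
    out_proj_weight, out_proj_bias = torch.randn(E, E), torch.zeros(E)
    attn_output, attn_weights = multi_head_attention_forward(
        query, key, value, E, num_heads,
        in_proj_weight, in_proj_bias,
        None, None, False, 0.0,            # bias_k, bias_v, add_zero_attn, dropout_p
        out_proj_weight, out_proj_bias)
    assert attn_output.shape == (L, N, E)   # (target length, batch, embedding dim)
    assert attn_weights.shape == (N, L, S)  # head-averaged attention weights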
| 322 | 
            +
            # This class exists solely for Transformer; it has an annotation stating
         | 
| 323 | 
            +
            # that bias is never None, which appeases TorchScript
         | 
| 324 | 
            +
            class _LinearWithBias(nn.Linear):
         | 
| 325 | 
            +
                bias: Tensor  # type: ignore
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                def __init__(self, in_features: int, out_features: int) -> None:
         | 
| 328 | 
            +
                    super().__init__(in_features, out_features, bias=True)  # type: ignore
         | 
| 329 | 
            +
             | 
| 330 | 
            +
             | 
| 331 | 
            +
            class MultiheadAttention(nn.Module):
         | 
| 332 | 
            +
                r"""Allows the model to jointly attend to information
         | 
| 333 | 
            +
                from different representation subspaces.
         | 
| 334 | 
            +
                See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                .. math::
         | 
| 337 | 
            +
                    \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                Args:
         | 
| 342 | 
            +
                    embed_dim: total dimension of the model.
         | 
| 343 | 
            +
                    num_heads: parallel attention heads.
         | 
| 344 | 
            +
                    dropout: a Dropout layer on attn_output_weights. Default: 0.0.
         | 
| 345 | 
            +
                    bias: add bias as module parameter. Default: True.
         | 
| 346 | 
            +
                    add_bias_kv: add bias to the key and value sequences at dim=0.
         | 
| 347 | 
            +
                    add_zero_attn: add a new batch of zeros to the key and
         | 
| 348 | 
            +
                                   value sequences at dim=1.
         | 
| 349 | 
            +
                    kdim: total number of features in key. Default: None.
         | 
| 350 | 
            +
                    vdim: total number of features in value. Default: None.
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
         | 
| 353 | 
            +
                to :attr:`embed_dim` such that query, key, and value have the same
         | 
| 354 | 
            +
                number of features.
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                Examples::
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
         | 
| 359 | 
            +
                    >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
         | 
| 360 | 
            +
                """
         | 
| 361 | 
            +
                bias_k: Optional[torch.Tensor]
         | 
| 362 | 
            +
                bias_v: Optional[torch.Tensor]
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
         | 
| 365 | 
            +
                    super(MultiheadAttention, self).__init__()
         | 
| 366 | 
            +
                    self.embed_dim = embed_dim
         | 
| 367 | 
            +
                    self.kdim = kdim if kdim is not None else embed_dim
         | 
| 368 | 
            +
                    self.vdim = vdim if vdim is not None else embed_dim
         | 
| 369 | 
            +
                    self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    self.num_heads = num_heads
         | 
| 372 | 
            +
                    self.dropout = dropout
         | 
| 373 | 
            +
                    self.head_dim = embed_dim // num_heads
         | 
| 374 | 
            +
                    assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    if self._qkv_same_embed_dim is False:
         | 
| 377 | 
            +
                        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
         | 
| 378 | 
            +
                        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
         | 
| 379 | 
            +
                        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
         | 
| 380 | 
            +
                        self.register_parameter('in_proj_weight', None)
         | 
| 381 | 
            +
                    else:
         | 
| 382 | 
            +
                        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
         | 
| 383 | 
            +
                        self.register_parameter('q_proj_weight', None)
         | 
| 384 | 
            +
                        self.register_parameter('k_proj_weight', None)
         | 
| 385 | 
            +
                        self.register_parameter('v_proj_weight', None)
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    if bias:
         | 
| 388 | 
            +
                        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
         | 
| 389 | 
            +
                    else:
         | 
| 390 | 
            +
                        self.register_parameter('in_proj_bias', None)
         | 
| 391 | 
            +
                    self.out_proj = _LinearWithBias(embed_dim, embed_dim)
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                    if add_bias_kv:
         | 
| 394 | 
            +
                        self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
         | 
| 395 | 
            +
                        self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
         | 
| 396 | 
            +
                    else:
         | 
| 397 | 
            +
                        self.bias_k = self.bias_v = None
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    self.add_zero_attn = add_zero_attn
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                    self._reset_parameters()
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                def _reset_parameters(self):
         | 
| 404 | 
            +
                    if self._qkv_same_embed_dim:
         | 
| 405 | 
            +
                        xavier_uniform_(self.in_proj_weight)
         | 
| 406 | 
            +
                    else:
         | 
| 407 | 
            +
                        xavier_uniform_(self.q_proj_weight)
         | 
| 408 | 
            +
                        xavier_uniform_(self.k_proj_weight)
         | 
| 409 | 
            +
                        xavier_uniform_(self.v_proj_weight)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    if self.in_proj_bias is not None:
         | 
| 412 | 
            +
                        constant_(self.in_proj_bias, 0.)
         | 
| 413 | 
            +
                        constant_(self.out_proj.bias, 0.)
         | 
| 414 | 
            +
                    if self.bias_k is not None:
         | 
| 415 | 
            +
                        xavier_normal_(self.bias_k)
         | 
| 416 | 
            +
                    if self.bias_v is not None:
         | 
| 417 | 
            +
                        xavier_normal_(self.bias_v)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                def __setstate__(self, state):
         | 
| 420 | 
            +
                    # Support loading old MultiheadAttention checkpoints generated by v1.1.0
         | 
| 421 | 
            +
                    if '_qkv_same_embed_dim' not in state:
         | 
| 422 | 
            +
                        state['_qkv_same_embed_dim'] = True
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    super(MultiheadAttention, self).__setstate__(state)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
         | 
| 427 | 
            +
                            need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
         | 
| 428 | 
            +
                    r"""
         | 
| 429 | 
            +
                Args:
         | 
| 430 | 
            +
                    query, key, value: map a query and a set of key-value pairs to an output.
         | 
| 431 | 
            +
                        See "Attention Is All You Need" for more details.
         | 
| 432 | 
            +
                    key_padding_mask: if provided, specified padding elements in the key will
         | 
| 433 | 
            +
                        be ignored by the attention. When given a binary mask, positions with a value of
         | 
| 434 | 
            +
                        True will be ignored by the attention layer. When given a byte mask, positions
         | 
| 435 | 
            +
                        with a non-zero value will be ignored by the attention
         | 
| 436 | 
            +
                        layer.
         | 
| 437 | 
            +
                    need_weights: output attn_output_weights.
         | 
| 438 | 
            +
                    attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
         | 
| 439 | 
            +
                        the batches while a 3D mask allows to specify a different mask for the entries of each batch.
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                Shapes for inputs:
         | 
| 442 | 
            +
                    - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
         | 
| 443 | 
            +
                      the embedding dimension.
         | 
| 444 | 
            +
                    - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
         | 
| 445 | 
            +
                      the embedding dimension.
         | 
| 446 | 
            +
                    - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
         | 
| 447 | 
            +
                      the embedding dimension.
         | 
| 448 | 
            +
                    - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
         | 
| 449 | 
            +
                      If a ByteTensor is provided, the non-zero positions will be ignored while the zero
         | 
| 450 | 
            +
                      positions will be unchanged. If a BoolTensor is provided, the positions with the
         | 
| 451 | 
            +
                      value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
         | 
| 452 | 
            +
                    - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
         | 
| 453 | 
            +
                      source sequence length.
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                      If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
         | 
| 456 | 
            +
                      length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
         | 
| 457 | 
            +
                      the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
         | 
| 458 | 
            +
                      while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
         | 
| 459 | 
            +
                      are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
         | 
| 460 | 
            +
                      is provided, it will be added to the attention weight.
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                Shapes for outputs:
         | 
| 463 | 
            +
                    - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
         | 
| 464 | 
            +
                      E is the embedding dimension.
         | 
| 465 | 
            +
                    - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
         | 
| 466 | 
            +
                      L is the target sequence length, S is the source sequence length.
         | 
| 467 | 
            +
                    """
         | 
| 468 | 
            +
                    if not self._qkv_same_embed_dim:
         | 
| 469 | 
            +
                        return multi_head_attention_forward(
         | 
| 470 | 
            +
                            query, key, value, self.embed_dim, self.num_heads,
         | 
| 471 | 
            +
                            self.in_proj_weight, self.in_proj_bias,
         | 
| 472 | 
            +
                            self.bias_k, self.bias_v, self.add_zero_attn,
         | 
| 473 | 
            +
                            self.dropout, self.out_proj.weight, self.out_proj.bias,
         | 
| 474 | 
            +
                            training=self.training,
         | 
| 475 | 
            +
                            key_padding_mask=key_padding_mask, need_weights=need_weights,
         | 
| 476 | 
            +
                            attn_mask=attn_mask, use_separate_proj_weight=True,
         | 
| 477 | 
            +
                            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
         | 
| 478 | 
            +
                            v_proj_weight=self.v_proj_weight)
         | 
| 479 | 
            +
                    else:
         | 
| 480 | 
            +
                        return multi_head_attention_forward(
         | 
| 481 | 
            +
                            query, key, value, self.embed_dim, self.num_heads,
         | 
| 482 | 
            +
                            self.in_proj_weight, self.in_proj_bias,
         | 
| 483 | 
            +
                            self.bias_k, self.bias_v, self.add_zero_attn,
         | 
| 484 | 
            +
                            self.dropout, self.out_proj.weight, self.out_proj.bias,
         | 
| 485 | 
            +
                            training=self.training,
         | 
| 486 | 
            +
                            key_padding_mask=key_padding_mask, need_weights=need_weights,
         | 
| 487 | 
            +
                            attn_mask=attn_mask)
         | 
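As a quick reference for the module defined in this file, a hedged usage sketch follows; it exercises the separate-projection path (kdim/vdim different from embed_dim) together with a boolean key_padding_mask, and the concrete sizes are made up for illustration:

    import torch
    mha = MultiheadAttention(embed_dim=16, num_heads=4, kdim=24, vdim=32)
    query = torch.randn(5, 2, 16)    # (L, N, E)
    key = torch.randn(7, 2, 24)      # (S, N, kdim)
    value = torch.randn(7, 2, 32)    # (S, N, vdim)
    key_padding_mask = torch.zeros(2, 7, dtype=torch.bool)
    key_padding_mask[:, -1] = True   # True marks padded source positions to ignore
    attn_output, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)
    # attn_output: (5, 2, 16); attn_weights: (2, 5, 7), averaged over the 4 heads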
    	
        modeling/modules/criterion.py
    ADDED
    
    | @@ -0,0 +1,874 @@ | |
|  | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
         | 
| 3 | 
            +
            # Copyright (c) 2022 Microsoft
         | 
| 4 | 
            +
            # Licensed under The MIT License [see LICENSE for details]
         | 
| 5 | 
            +
            # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 6 | 
            +
            # --------------------------------------------------------
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 9 | 
            +
            # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
         | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
            MaskFormer criterion.
         | 
| 12 | 
            +
            """
         | 
| 13 | 
            +
            import logging
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import torch
         | 
| 16 | 
            +
            import torch.nn.functional as F
         | 
| 17 | 
            +
            from torch import nn
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from detectron2.utils.comm import get_world_size
         | 
| 20 | 
            +
            from timm.loss import SoftTargetCrossEntropy
         | 
| 21 | 
            +
            from .point_features import (
         | 
| 22 | 
            +
                get_uncertain_point_coords_with_randomness,
         | 
| 23 | 
            +
                point_sample,
         | 
| 24 | 
            +
            )
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from ..language.loss import ql_multi_contrastive_loss, image_text_contrastive_loss_queue, vl_similarity, all_gather_grad
         | 
| 27 | 
            +
            from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis
         | 
| 28 | 
            +
            from ..utils import box_ops
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # from image2html.visualizer import VL
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def dice_loss(
         | 
| 34 | 
            +
                    inputs: torch.Tensor,
         | 
| 35 | 
            +
                    targets: torch.Tensor,
         | 
| 36 | 
            +
                    num_masks: float,
         | 
| 37 | 
            +
                ):
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                Compute the DICE loss, similar to generalized IOU for masks
         | 
| 40 | 
            +
                Args:
         | 
| 41 | 
            +
                    inputs: A float tensor of arbitrary shape.
         | 
| 42 | 
            +
                            The predictions for each example.
         | 
| 43 | 
            +
                    targets: A float tensor with the same shape as inputs. Stores the binary
         | 
| 44 | 
            +
                             classification label for each element in inputs
         | 
| 45 | 
            +
                            (0 for the negative class and 1 for the positive class).
         | 
| 46 | 
            +
                """
         | 
| 47 | 
            +
                inputs = inputs.sigmoid()
         | 
| 48 | 
            +
                inputs = inputs.flatten(1)
         | 
| 49 | 
            +
                numerator = 2 * (inputs * targets).sum(-1)
         | 
| 50 | 
            +
                denominator = inputs.sum(-1) + targets.sum(-1)
         | 
| 51 | 
            +
                loss = 1 - (numerator + 1) / (denominator + 1)
         | 
| 52 | 
            +
                return loss.sum() / num_masks
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            dice_loss_jit = torch.jit.script(
         | 
| 56 | 
            +
                dice_loss
         | 
| 57 | 
            +
            )  # type: torch.jit.ScriptModule
         | 
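A small hedged sanity check of dice_loss above; the tensors are made up for illustration. A near-perfect prediction drives the loss toward 0, while a fully wrong one keeps it high (about 0.8 here because of the +1 smoothing terms):

    import torch
    targets = torch.tensor([[0., 1., 1., 0.]])             # one mask, four pixels
    good_logits = torch.tensor([[-10., 10., 10., -10.]])   # sigmoid ~ targets
    bad_logits = -good_logits                               # sigmoid ~ 1 - targets
    print(dice_loss(good_logits, targets, num_masks=1.0))   # ~0
    print(dice_loss(bad_logits, targets, num_masks=1.0))    # ~0.8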
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            def sigmoid_ce_loss(
         | 
| 61 | 
            +
                    inputs: torch.Tensor,
         | 
| 62 | 
            +
                    targets: torch.Tensor,
         | 
| 63 | 
            +
                    num_masks: float,
         | 
| 64 | 
            +
                ):
         | 
| 65 | 
            +
                """
         | 
| 66 | 
            +
                Args:
         | 
| 67 | 
            +
                    inputs: A float tensor of arbitrary shape.
         | 
| 68 | 
            +
                            The predictions for each example.
         | 
| 69 | 
            +
                    targets: A float tensor with the same shape as inputs. Stores the binary
         | 
| 70 | 
            +
                             classification label for each element in inputs
         | 
| 71 | 
            +
                            (0 for the negative class and 1 for the positive class).
         | 
| 72 | 
            +
                Returns:
         | 
| 73 | 
            +
                    Loss tensor
         | 
| 74 | 
            +
                """
         | 
| 75 | 
            +
                loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                return loss.mean(1).sum() / num_masks
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            sigmoid_ce_loss_jit = torch.jit.script(
         | 
| 81 | 
            +
                sigmoid_ce_loss
         | 
| 82 | 
            +
            )  # type: torch.jit.ScriptModule
         | 
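Similarly, an illustrative call of sigmoid_ce_loss above: binary cross-entropy on logits, averaged over the sampled points of each mask and then normalized by num_masks (the values are assumptions for the sketch):

    import torch
    inputs = torch.tensor([[2.0, -2.0, 0.0]])   # one mask, three sampled points (logits)
    targets = torch.tensor([[1.0, 0.0, 1.0]])
    print(sigmoid_ce_loss(inputs, targets, num_masks=1.0))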
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def calculate_uncertainty(logits):
         | 
| 86 | 
            +
                """
         | 
| 87 | 
            +
                We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
         | 
| 88 | 
            +
                    foreground class in `classes`.
         | 
| 89 | 
            +
                Args:
         | 
| 90 | 
            +
                    logits (Tensor): A class-agnostic tensor of shape (R, 1, ...), where R is the total
         | 
| 91 | 
            +
                        number of predicted masks in all images and the channel dimension is 1.
         | 
| 92 | 
            +
                        The values are logits.
         | 
| 93 | 
            +
                Returns:
         | 
| 94 | 
            +
                    scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
         | 
| 95 | 
            +
                        the most uncertain locations having the highest uncertainty score.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                assert logits.shape[1] == 1
         | 
| 98 | 
            +
                gt_class_logits = logits.clone()
         | 
| 99 | 
            +
                return -(torch.abs(gt_class_logits))
         | 
| 100 | 
            +
             | 
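A tiny illustration of the uncertainty score above (made-up logits): values near 0 are the most ambiguous and therefore receive the highest score, which the point-sampling utilities imported at the top of this file use to pick hard locations:

    import torch
    logits = torch.tensor([[[5.0, 0.1, -3.0]]])   # (R=1, 1, 3 points)
    print(calculate_uncertainty(logits))           # tensor([[[-5.0000, -0.1000, -3.0000]]])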
| 101 | 
            +
             | 
| 102 | 
            +
            class SetCriterion(nn.Module):
         | 
| 103 | 
            +
                """This class computes the loss for DETR.
         | 
| 104 | 
            +
                The process happens in two steps:
         | 
| 105 | 
            +
                    1) we compute the Hungarian assignment between ground truth boxes and the outputs of the model
         | 
| 106 | 
            +
                    2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
         | 
| 107 | 
            +
                """
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses,
         | 
| 110 | 
            +
                             num_points, oversample_ratio, importance_sample_ratio, grounding_weight):
         | 
| 111 | 
            +
                    """Create the criterion.
         | 
| 112 | 
            +
                    Parameters:
         | 
| 113 | 
            +
                        num_classes: number of object categories, omitting the special no-object category
         | 
| 114 | 
            +
                        matcher: module able to compute a matching between targets and proposals
         | 
| 115 | 
            +
                        weight_dict: dict containing as key the names of the losses and as values their relative weight.
         | 
| 116 | 
            +
                        eos_coef: relative classification weight applied to the no-object category
         | 
| 117 | 
            +
                        losses: list of all the losses to be applied. See get_loss for list of available losses.
         | 
| 118 | 
            +
                    """
         | 
| 119 | 
            +
                    super().__init__()
         | 
| 120 | 
            +
                    self.num_classes = num_classes
         | 
| 121 | 
            +
                    self.matcher = matcher
         | 
| 122 | 
            +
                    self.weight_dict = weight_dict
         | 
| 123 | 
            +
                    self.eos_coef = eos_coef
         | 
| 124 | 
            +
                    self.top_x_layers = top_x_layers
         | 
| 125 | 
            +
                    self.losses = losses
         | 
| 126 | 
            +
                    empty_weight = torch.ones(self.num_classes + 1)
         | 
| 127 | 
            +
                    empty_weight[-1] = self.eos_coef
         | 
| 128 | 
            +
                    self.register_buffer("empty_weight", empty_weight)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # pointwise mask loss parameters
         | 
| 131 | 
            +
                    self.num_points = num_points
         | 
| 132 | 
            +
                    self.oversample_ratio = oversample_ratio
         | 
| 133 | 
            +
                    self.importance_sample_ratio = importance_sample_ratio
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # grounding
         | 
| 136 | 
            +
                    self.grounding_weight = grounding_weight
         | 
| 137 | 
            +
             | 
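To make the constructor arguments above concrete, a hedged, illustrative instantiation follows; every value here is an assumption for the sketch (the real matcher, loss weights, and point-sampling settings come from this repository's configs), and the stand-in matcher only works because __init__ merely stores it:

    dummy_matcher = object()   # placeholder; a Hungarian matcher is used in practice
    criterion = SetCriterion(
        num_classes=16,
        matcher=dummy_matcher,
        weight_dict={"loss_mask_ce_0": 2.0},
        eos_coef=0.1,
        top_x_layers={"mask": 10, "retrieval": 10},
        losses=["labels"],
        num_points=112 * 112,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        grounding_weight=None,
    )
    print(criterion.empty_weight.shape)   # torch.Size([17]); the last entry is eos_coef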
| 138 | 
            +
    def loss_labels(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        if layer_id > self.top_x_layers['mask']:
            return {"loss_mask_ce_0": 0}

        if indices is None or len(targets) == 0:
            loss_ce = outputs['pred_logits'].sum() * 0.0
            losses = {"loss_mask_ce_0": loss_ce}
            return losses

        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"].type(self.empty_weight.dtype)

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        if src_logits.shape[2] == self.num_classes+1:
            empty_weight = torch.ones(self.num_classes + 1).to(src_logits.device).type(self.empty_weight.dtype)
            empty_weight[-1] = self.eos_coef
        else:
            empty_weight = torch.ones(self.num_classes + 1000 + 1).to(src_logits.device).type(self.empty_weight.dtype)
            empty_weight[self.num_classes] = self.eos_coef
        # down-weight the no-object class by eos_coef so unmatched queries do not dominate the loss
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)
        losses = {"loss_mask_ce_0": loss_ce}
        return losses

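    # loss_labels_openimage builds a one-hot target over the matched (query, label) pairs of each image
    # and applies a softmax cross-entropy over the matched query logits, averaged across the batch.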
    def loss_labels_openimage(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        if layer_id > self.top_x_layers['mask']:
            return {"loss_openimage_ce_0": 0}

        assert "pred_captions" in outputs

        if indices is None or len(targets) == 0 or (len(targets) > 0 and len(targets[0]['labels']) == 0):
            loss_ce = outputs['pred_captions'].sum() * 0.0
            losses = {"loss_openimage_ce_0": loss_ce}
            return losses

        # compute i2t loss
        loss_openimage_ce = 0
        losses = {}
        for b in range(len(indices)):
            pred_logit = outputs["pred_logits"][b][indices[b][0]]
            gt_logit = torch.zeros_like(pred_logit)
            select_idx = torch.stack((torch.arange(len(indices[b][1])), indices[b][1])).tolist()
            gt_logit[select_idx] = 1
            loss_openimage_ce += torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
        loss_openimage_ce = loss_openimage_ce / len(indices)
        losses.update({"loss_openimage_ce_0": loss_openimage_ce})
        return losses

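    # loss_itc combines a queue-based image-text contrastive loss on the class token
    # (pred_captions[:, -1]) with a fine-grained query-to-token contrastive term; the i2t and t2i
    # directions are mixed 0.7/0.3 and the fine-grained term is down-weighted by 0.5.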
    def loss_itc(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['retrieval']:
            return {"loss_retrieval_decoder_0": 0}
        t_emb = torch.cat([x['caption_proj'] for x in targets], dim=0)
        v_emb = outputs['pred_captions'][:,-1]
        loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, extra['lang_encoder'], extra['training'])

        # compute query-token contrastive loss
        ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
        ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
        ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
        vtk_emb = outputs['pred_captions'][:,:-1]
        keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()

        ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)

        # prepare gt
        gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
        gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
        # compute i2t loss
        logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
        loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
        # loss_contrast_fine = loss_contrast_fine_vt # i2t only

        # compute t2i loss
        bs, nq, _ = vtk_emb.shape
        logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
        loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
        # compute loss
        loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)

        losses = {"loss_retrieval_decoder_0": loss_contrast + loss_contrast_fine * 0.5}
        return losses

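    # loss_captionings is a teacher-forced captioning loss: query states are projected onto the
    # token-embedding matrix and compared against the next-token ids, masked by caption_mask.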
    def loss_captionings(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['captioning']:
            return {"loss_captioning_0": 0}

        pred_captions_gen = outputs['pred_captionings'][:, :-1]
        token_embs = extra['token_embedding'].weight
        # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
        # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
        pred_captions_gen = pred_captions_gen @ token_embs.t()

        # temperature = extra['lang_encoder'].logit_scale
        # logit_scale = temperature.exp().clamp(max=100)

        target_captions_gen = torch.cat([target['caption_tokenids'] for target in targets], 0)[:, 1:]
        target_captions_gen_mask = torch.cat([target['caption_mask'] for target in targets], 0)[:, 1:]

        # loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2) * logit_scale, target_captions_gen, reduction='none')
        loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2), target_captions_gen, reduction='none')
        loss_caption = (loss_caption * target_captions_gen_mask).sum() / (target_captions_gen_mask.sum() + 1)
        losses = {"loss_captioning_0": loss_caption}
        return losses

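    # loss_captions contrasts query embeddings with text embeddings via ql_multi_contrastive_loss:
    # queries already matched to class labels are handled separately from the remaining queries,
    # which are re-matched to caption embeddings in "caption_womask" mode.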
    def loss_captions(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['caption']:
            return {"loss_caption_0": 0}
        matched_tokens = [m[0] for m in indices]
        t_emb_class = torch.cat([extra['class_embeddings'][targets[bs]['labels'][m[1]]] for bs, m in enumerate(indices)])
        t_hash_class = torch.cat([torch.tensor(targets[bs]['labels_hash'])[m[1]] for bs, m in enumerate(indices)])

        # pred_captions denotes all unmatched object queries.
        unmatched_pred_captions = []
        matched_pred_captions = []
        for idx, m in enumerate(matched_tokens):
            unmatched_masks = torch.ones(outputs['pred_captions'].shape[1:-1]).bool()
            matched_masks = torch.zeros(outputs['pred_captions'].shape[1:-1]).bool()

            unmatched_masks[m] = False
            matched_masks[m] = True

            unmatched_pred_captions.append(outputs['pred_captions'][idx][unmatched_masks])
            matched_pred_captions.append(outputs['pred_captions'][idx][matched_masks])

        outputs['unmatched_pred_captions'] = unmatched_pred_captions
        v_emb_class = torch.cat(matched_pred_captions)
        v_emb_class = v_emb_class / (v_emb_class.norm(dim=-1, keepdim=True) + 1e-7)

        indices = self.matcher(outputs, targets, mode="caption_womask", extra={'temperature':extra['lang_logit']})
        src_idx = self._get_src_permutation_idx(indices)

        t_emb = torch.cat([t['captions'][indices[bs][1]] for bs,t in enumerate(targets)])
        t_hash = torch.cat([torch.tensor(t['captions_hash'])[indices[bs][1]] for bs,t in enumerate(targets)])

        unmatched_pred_captions, _ = nested_tensor_from_tensor_list(unmatched_pred_captions).decompose()
        v_emb = unmatched_pred_captions[src_idx]
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        loss_contrast = ql_multi_contrastive_loss(torch.cat((v_emb, v_emb_class)), torch.cat((t_emb, t_emb_class)), torch.cat((t_hash, t_hash_class)), temperature=extra['lang_logit'])
        losses = {"loss_caption_0": loss_contrast}

        return losses

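    # loss_masks uses the point-sampling recipe (as in PointRend/Mask2Former): rather than comparing
    # full-resolution masks densely, it samples self.num_points coordinates biased toward uncertain
    # regions and evaluates the sigmoid BCE and Dice losses only at those points.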
    def loss_masks(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        if layer_id >= self.top_x_layers['mask']:
            return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0}

        assert "pred_masks" in outputs
        if indices is None or len(targets) == 0:
            loss = outputs['pred_masks'].sum() * 0.0
            losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss}
            return losses

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        losses = {
            "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
            "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
        }

        del src_masks
        del target_masks
        return losses

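    # loss_groundings applies the same point-sampled BCE/Dice to the grounding masks and adds a
    # grounding cross-entropy that aligns query-text similarity logits (vl_similarity) with the
    # matched pairs; targets sharing the same grounding_hash are treated as equivalent positives.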
    def loss_groundings(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the grounding masks: the sigmoid cross-entropy loss,
        the dice loss and the grounding cross-entropy loss.
        targets dicts must contain the key "grounding_masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_gmasks" in outputs
        assert "pred_gtexts" in outputs

        if layer_id >= self.top_x_layers['grounding']:
            return {"loss_grounding_bce_0": 0, "loss_grounding_dice_0": 0, "loss_grounding_ce_0": 0}

        masks = [t["grounding_masks"] for t in targets]
        if indices is None or None in masks:
            loss = outputs['pred_gmasks'].sum() * 0.0
            return {"loss_grounding_bce_0": loss, "loss_grounding_dice_0": loss, "loss_grounding_ce_0": loss}

        pred_logits = []
        for b in range(len(indices)):
            t_emb = targets[b]['grounding_class_embs']
            v_emb = outputs["pred_gtexts"][b]

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
            pred_logits += [out_prob]
        outputs['pred_logits'] = pred_logits

        indices = self.matcher(outputs, targets, mode='grounding', extra={'temperature':extra['lang_logit']})
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_gmasks"]
        src_masks = src_masks[src_idx]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        losses = {
            "loss_grounding_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, len(src_masks)),
            "loss_grounding_dice_0": dice_loss_jit(point_logits, point_labels, len(src_masks)),
        }

        # compute query-token contrastive loss
        # ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
        # ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
        # ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
        # vtk_emb = outputs['pred_captions'][:,:-1]
        # keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()

        # ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        # vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        # logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)

        # # prepare gt
        # gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
        # gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
        # # compute i2t loss
        # logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
        # loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
        # # loss_contrast_fine = loss_contrast_fine_vt # i2t only

        # # compute t2i loss
        # bs, nq, _ = vtk_emb.shape
        # logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
        # loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
        # # compute loss
        # loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)

        # compute grounding cross-entropy loss between queries and grounded text embeddings
        loss_grd_ce = 0
        for b in range(len(indices)):
            task = targets[b]['grounding_task']
            pred_logit = outputs["pred_logits"][b]
            gt_logit = torch.zeros_like(pred_logit)
            select_idx = torch.stack((indices[b][0], indices[b][1])).tolist()
            gt_logit[select_idx] = 1
            t_hash = torch.tensor(targets[b]['grounding_hash'], device=gt_logit.device)
            hash_table = torch.zeros((len(t_hash), len(t_hash)), device=gt_logit.device)
            for idx in range(0, len(hash_table)):
                hash_table[idx][t_hash==t_hash[idx]] = 1
            hash_table = hash_table / hash_table.sum(-1, keepdim=True)
            gt_logit = gt_logit @ hash_table
            loss_grd_ce += self.grounding_weight[task]*torch.sum(-gt_logit.t() * F.log_softmax(pred_logit.t(), dim=-1), dim=-1).mean()
        loss_grd_ce = loss_grd_ce / len(indices)
        losses.update({"loss_grounding_ce_0": loss_grd_ce})
        del src_masks
        del target_masks
        return losses

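    # loss_spatials supervises interactive (spatial-prompt) segmentation: queries are matched to
    # positive spatial prompts via pred_smaskembs @ pred_pspatials, a softmax cross-entropy
    # encourages the matched query, and the matched masks get the point-sampled BCE/Dice terms;
    # samples without any positive ground-truth mask are skipped.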
    def loss_spatials(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the spatial (interactive) masks: the sigmoid cross-entropy loss,
        the dice loss and the spatial cross-entropy loss.
        targets dicts must contain the key "gt_spatial_masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_smasks" in outputs
        assert "pred_smaskembs" in outputs

        if layer_id >= self.top_x_layers['spatial']:
            loss = outputs['pred_smasks'].sum() * 0.0
            loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
            return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}

        gt_masks = [x['gt_spatial_masks'] for x in targets]
        # compute a keep index with batch size to avoid empty gt_masks
        stack_gt_mask = torch.cat(gt_masks)
        bs,_,_ = stack_gt_mask.shape
        stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
        keep = stack_gt_mask > 0 # only keep samples that contain a positive mask

        if keep.sum() == 0:
            loss = outputs['pred_smasks'].sum() * 0.0
            loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
            return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}

        # mask embedding logits
        v_emb = outputs["pred_smaskembs"] # [bs, nq, 512]

        # pos mask
        s_emb = outputs["pred_pspatials"] # [bs, ns, 512]
        pred_logits = v_emb @ s_emb.transpose(1,2)
        outputs['pred_pos_logits'] = pred_logits # [bs, nq, 1]
        indices = self.matcher(outputs, targets, mode='spatial', extra={})
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        # pos class loss
        pred_logit = torch.cat([o[:len(t['gt_spatial_masks'])] for o,t in zip(outputs["pred_pos_logits"].transpose(1,2), targets)])
        gt_logit = torch.zeros_like(pred_logit)
        gt_logit = gt_logit[keep]
        _src_idx = [torch.arange(keep.sum(), device=src_idx[0].device), src_idx[1][keep.cpu()]]
        gt_logit[_src_idx] = 1
        pred_logit = pred_logit[keep]
        loss_spa_ce_pos = torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()

        # neg mask
        # s_emb = outputs["pred_nspatials"] # [bs, ns, 512]
        # neg_mask = (s_emb.sum(dim=list(range(1, len(s_emb.shape)))) != 0).float()[keep]
        # pred_logits = v_emb @ s_emb.transpose(1,2)
        # outputs['pred_neg_logits'] = pred_logits # [bs, nq, 1]
        # indices = self.matcher(outputs, targets, mode='spatial_pn', extra=extra)
        # src_idx = self._get_src_permutation_idx(indices)
        # tgt_idx = self._get_tgt_permutation_idx(indices)
        # src_masks_neg = outputs["pred_smasks"][src_idx][keep]
        # src_masks_neg = src_masks_neg*(neg_mask[:,None,None])
        # src_masks_neg = src_masks_neg.clip(0) * (-1)

        # neg class loss
        # pred_logit = outputs["pred_neg_logits"]
        # gt_logit = torch.zeros_like(pred_logit)
        # gt_logit[src_idx] = 1
        # bs,_,ns = pred_logit[keep].shape
        # pred_logit = pred_logit[keep].transpose(1,2).view(bs*ns,-1)
        # gt_logit = gt_logit[keep].transpose(1,2).view(bs*ns,-1)
        # loss_spa_ce_neg = (torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1)*neg_mask).sum() / (neg_mask.sum()+1e-6)

        # recompute a keep index with matched tgt
        stack_gt_mask = nn.utils.rnn.pad_sequence(gt_masks, padding_value=-1).transpose(0,1)[tgt_idx]
        bs,_,_ = stack_gt_mask.shape
        target_masks = stack_gt_mask
        stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
        keep = stack_gt_mask > 0 # only keep samples that contain a positive mask
        src_masks_pos = outputs["pred_smasks"][src_idx][keep]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks = target_masks.to(src_masks_pos)
        target_masks = target_masks[keep]

        # mul = extra['spatial_query_mode'][keep]
        # src_masks_cur = src_masks_cur.clip(0) * mul[:,None,None]
        # src_masks_cur = src_masks_cur

        # if neg_mask[0] == 1:
        #     import cv2
        #     print(src_masks_pos.shape)
        #     print(src_masks_neg.shape)
        #     print(target_masks.shape)
        #     # import pdb; pdb.set_trace()
        #     v_pos_mask = (src_masks_pos[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_neg_mask = (_src_masks_neg[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_sum = ((src_masks_pos[0]-_src_masks_neg[0].clip(0)).sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_gt = target_masks[0].float().cpu().detach().numpy() * 255

        #     cv2.imwrite('v_pos_mask.png', v_pos_mask)
        #     cv2.imwrite('v_neg_mask.png', v_neg_mask)
        #     cv2.imwrite('v_sum.png', v_sum)
        #     cv2.imwrite('v_gt.png', v_gt)
        #     import pdb; pdb.set_trace()

        # src_masks = (src_masks_pos + src_masks_neg)[:, None]
        src_masks = src_masks_pos[:, None]
        target_masks = target_masks[:, None]

        # debug visualization
        # with torch.no_grad():
        #     import cv2
        #     import numpy as np

        #     v_src_masks = (F.interpolate(src_masks, size=target_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5).float().cpu().numpy()[:,0] * 255
        #     v_target_masks = target_masks.float().cpu().numpy()[:,0] * 255
        #     v_masks = np.concatenate([v_src_masks, v_target_masks], axis=2)

        #     for i in range(len(src_masks)):
        #         v1 = v_src_masks[i]
        #         v2 = v_target_masks[i]
        #         v = np.concatenate([v1,v2], axis=1)
        #         cv2.imwrite('v{}.png'.format(i), v)
        #     import pdb; pdb.set_trace()

        # visualization
        # VL.step()
        # v_img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
        # VL.add_image(v_img[:,:,::-1])
        # candidate_masks = batched_inputs[0]['spatial_query']['rand_shape'].float().cpu().numpy()
        # gt_masks = batched_inputs[0]['spatial_query']['gt_masks'].float().cpu().numpy()
        # texts = ['cmask' for i in range(len(candidate_masks))]
        # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], candidate_masks, texts)
        # texts = ['gmask' for i in range(len(candidate_masks))]
        # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], gt_masks, texts)

        # import cv2
        # for i in range(len(src_masks)):
        #     visual_src_mask_cur = (src_masks_cur[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_src_mask_mem = (src_masks_mem[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_src_mask = (src_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_target_mask = (target_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255

        #     cv2.imwrite('visual_src_mask_cur_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_cur)
        #     cv2.imwrite('visual_src_mask_mem_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_mem)
        #     cv2.imwrite('visual_src_mask_{}_{}.png'.format(i, mul[i].item()), visual_src_mask)
        #     cv2.imwrite('visual_target_mask_{}_{}.png'.format(i, mul[i].item()), visual_target_mask)
        # import pdb; pdb.set_trace()

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        num_masks = len(src_masks)
        losses = {
            "loss_spatial_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
            "loss_spatial_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
        }

        # losses.update({"loss_spatial_ce_0": loss_spa_ce_pos + loss_spa_ce_neg})
        losses.update({"loss_spatial_ce_0": loss_spa_ce_pos})

        del src_masks
        del target_masks
        return losses

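    # loss_boxes pads per-image target boxes to a common shape, then computes an L1 loss and a GIoU
    # loss on (center_x, center_y, w, h) boxes normalized by image size, both divided by num_boxes.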
    def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id, extra):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if layer_id >= self.top_x_layers['box']:
            return {"loss_bbox_0": 0, "loss_giou_0": 0}

        assert 'pred_boxes' in outputs

        if indices is None or len(targets) == 0:
            loss = outputs['pred_boxes'].sum() * 0.0
            losses = {"loss_bbox_0": loss, "loss_giou_0": loss}
            return losses

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"]
        src_boxes = src_boxes[src_idx].sigmoid()

        target_boxes = [t['boxes'] for t in targets]
        max_size = _max_by_axis([list(box.shape) for box in target_boxes])
        max_size = [len(target_boxes)] + max_size
        empty_boxes = torch.zeros(max_size).to(src_boxes.device)
        for idx, tar_box in enumerate(target_boxes):
            empty_boxes[idx,:tar_box.shape[0],:] = tar_box
        target_boxes = empty_boxes[tgt_idx]

        # target_isthings = [t['is_things'] for t in targets]
        # max_size = _max_by_axis([list(lab.shape) for lab in target_isthings])
        # max_size = [len(target_isthings)] + max_size
        # empty_lab = torch.zeros(max_size).to(src_boxes.device)

        # for idx, tar_thing in enumerate(target_isthings):
        #     empty_lab[idx,:tar_thing.shape[0]] = tar_thing
        # target_isthings = empty_lab[tgt_idx]

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou_0'] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks, layer_id, extra):
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'boxes': self.loss_boxes,
            'captions': self.loss_captions,
            'retrievals': self.loss_itc,
            'captionings': self.loss_captionings,
            'groundings': self.loss_groundings,
            'labels_openimage': self.loss_labels_openimage,
            'spatials': self.loss_spatials,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks, layer_id, extra)

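    # forward() matches the final-layer predictions, normalizes by the number of target masks across
    # GPUs, and accumulates every requested loss; auxiliary decoder layers are processed in reverse
    # order, renaming keys, e.g. "loss_mask_ce_0" -> "loss_mask_ce_1" for the second-to-last layer.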
    def forward(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs_without_aux.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def forward_vlp(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        # Compute all the requested losses
        losses = {}
        num_masks = indices = None
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

                def forward_grounding(self, outputs, targets, extra=None):
         | 
| 782 | 
            +
                    """This performs the loss computation.
         | 
| 783 | 
            +
                    Parameters:
         | 
| 784 | 
            +
                         outputs: dict of tensors, see the output specification of the model for the format
         | 
| 785 | 
            +
                         targets: list of dicts, such that len(targets) == batch_size.
         | 
| 786 | 
            +
                              The expected keys in each dict depend on the losses applied, see each loss' doc
         | 
| 787 | 
            +
                    """
         | 
| 788 | 
            +
                    # Compute all the requested losses
         | 
| 789 | 
            +
                    losses = {}
         | 
| 790 | 
            +
                    indices = [[] for i in range(len(targets))]
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                    # Compute the average number of target boxes across all nodes, for normalization purposes
         | 
| 793 | 
            +
                    num_masks = sum(len(t["grounding_masks"]) for t in targets) + 1e-7
         | 
| 794 | 
            +
                    num_masks = torch.as_tensor(
         | 
| 795 | 
            +
                        [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
         | 
| 796 | 
            +
                    )
         | 
| 797 | 
            +
                    if is_dist_avail_and_initialized():
         | 
| 798 | 
            +
                        torch.distributed.all_reduce(num_masks)
         | 
| 799 | 
            +
                    num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                    for loss in self.losses:
         | 
| 802 | 
            +
                        losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
         | 
| 803 | 
            +
             | 
| 804 | 
            +
                    # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         | 
| 805 | 
            +
                    if "aux_outputs" in outputs:
         | 
| 806 | 
            +
                        # NOTE: we reverse the aux_outputs so that the first is the second last layer
         | 
| 807 | 
            +
                        for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
         | 
| 808 | 
            +
                            for loss in self.losses:
         | 
| 809 | 
            +
                                l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
         | 
| 810 | 
            +
                                l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
         | 
| 811 | 
            +
                                losses.update(l_dict)
         | 
| 812 | 
            +
             | 
| 813 | 
            +
                    return losses
         | 
| 814 | 
            +
             | 
| 815 | 
            +
                def forward_openimage(self, outputs, targets, extra=None):
         | 
| 816 | 
            +
                    """This performs the loss computation.
         | 
| 817 | 
            +
                    Parameters:
         | 
| 818 | 
            +
                         outputs: dict of tensors, see the output specification of the model for the format
         | 
| 819 | 
            +
                         targets: list of dicts, such that len(targets) == batch_size.
         | 
| 820 | 
            +
                              The expected keys in each dict depend on the losses applied, see each loss' doc
         | 
| 821 | 
            +
                    """
         | 
| 822 | 
            +
                    neg_class_emb =  all_gather_grad(torch.cat([x['neg_class_emb'] for x in targets]))
         | 
| 823 | 
            +
                    neg_hash = all_gather_grad(torch.cat([x['neg_hash'] for x in targets]))
         | 
| 824 | 
            +
             | 
| 825 | 
            +
                    extra['neg_class_emb'] = neg_class_emb
         | 
| 826 | 
            +
                    extra['neg_hash'] = neg_hash
         | 
| 827 | 
            +
                    outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
         | 
| 828 | 
            +
             | 
| 829 | 
            +
                    # Retrieve the matching between the outputs of the last layer and the targets
         | 
| 830 | 
            +
                    indices, pred_logits = self.matcher.openimage_forward(outputs_without_aux, targets, extra=extra)
         | 
| 831 | 
            +
                    outputs['pred_logits'] = pred_logits
         | 
| 832 | 
            +
             | 
| 833 | 
            +
                    # Compute the average number of target boxes across all nodes, for normalization purposes
         | 
| 834 | 
            +
                    num_masks = sum(len(t["labels"]) for t in targets)
         | 
| 835 | 
            +
                    num_masks = torch.as_tensor(
         | 
| 836 | 
            +
                        [num_masks], dtype=torch.float, device=neg_class_emb.device
         | 
| 837 | 
            +
                    )
         | 
| 838 | 
            +
                    if is_dist_avail_and_initialized():
         | 
| 839 | 
            +
                        torch.distributed.all_reduce(num_masks)
         | 
| 840 | 
            +
                    num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
         | 
| 841 | 
            +
             | 
| 842 | 
            +
                    # Compute all the requested losses
         | 
| 843 | 
            +
                    losses = {}
         | 
| 844 | 
            +
                    for loss in self.losses:
         | 
| 845 | 
            +
                        losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
         | 
| 846 | 
            +
             | 
| 847 | 
            +
                    # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         | 
| 848 | 
            +
                    if "aux_outputs" in outputs:
         | 
| 849 | 
            +
                        # NOTE: we reverse the aux_outputs so that the first is the second last layer
         | 
| 850 | 
            +
                        for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
         | 
| 851 | 
            +
                            indices, pred_logits = self.matcher.openimage_forward(aux_outputs, targets, extra=extra)
         | 
| 852 | 
            +
                            aux_outputs['pred_logits'] = pred_logits
         | 
| 853 | 
            +
                            for loss in self.losses:
         | 
| 854 | 
            +
                                l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
         | 
| 855 | 
            +
                                l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
         | 
| 856 | 
            +
                                losses.update(l_dict)
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                    return losses
         | 
| 859 | 
            +
             | 
| 860 | 
            +
                def __repr__(self):
         | 
| 861 | 
            +
                    head = "Criterion " + self.__class__.__name__
         | 
| 862 | 
            +
                    body = [
         | 
| 863 | 
            +
                        "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
         | 
| 864 | 
            +
                        "losses: {}".format(self.losses),
         | 
| 865 | 
            +
                        "weight_dict: {}".format(self.weight_dict),
         | 
| 866 | 
            +
                        "num_classes: {}".format(self.num_classes),
         | 
| 867 | 
            +
                        "eos_coef: {}".format(self.eos_coef),
         | 
| 868 | 
            +
                        "num_points: {}".format(self.num_points),
         | 
| 869 | 
            +
                        "oversample_ratio: {}".format(self.oversample_ratio),
         | 
| 870 | 
            +
                        "importance_sample_ratio: {}".format(self.importance_sample_ratio),
         | 
| 871 | 
            +
                    ]
         | 
| 872 | 
            +
                    _repr_indent = 4
         | 
| 873 | 
            +
                    lines = [head] + [" " * _repr_indent + line for line in body]
         | 
| 874 | 
            +
                    return "\n".join(lines)
         | 
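A minimal sketch (not part of the diff) of the normalization pattern shared by the forward_* methods above: each rank counts its own target masks, the counts are averaged over all distributed workers, and every rank then divides its losses by the same value. The helper name below is illustrative; the file itself relies on the repo's is_dist_avail_and_initialized / get_world_size utilities.

import torch
import torch.distributed as dist

def average_num_masks(targets, device, key="labels"):
    # Count targets on this rank, then average the count across all ranks.
    num_masks = sum(len(t[key]) for t in targets)
    num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(num_masks)                     # sum of per-rank counts
        num_masks = num_masks / dist.get_world_size()  # average per rank
    return torch.clamp(num_masks, min=1).item()        # never divide by zero

# For the auxiliary decoder layers the same value is reused; only the '_0' suffix of each
# loss key is rewritten to '_1', '_2', ... as done by the k.replace('_0', f"_{i+1}") line above.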
    	
        modeling/modules/matcher.py
    ADDED
    
    | @@ -0,0 +1,632 @@ | |
| 1 | 
            +
            # --------------------------------------------------------
         | 
| 2 | 
            +
            # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
         | 
| 3 | 
            +
            # Copyright (c) 2022 Microsoft
         | 
| 4 | 
            +
            # Licensed under The MIT License [see LICENSE for details]
         | 
| 5 | 
            +
            # Modified by Xueyan Zou (xueyan@cs.wisc.edu)
         | 
| 6 | 
            +
            # --------------------------------------------------------
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 9 | 
            +
            # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
         | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
            Modules to compute the matching cost and solve the corresponding LSAP.
         | 
| 12 | 
            +
            """
         | 
| 13 | 
            +
            import warnings
         | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torch.nn.functional as F
         | 
| 16 | 
            +
            import numpy as np
         | 
| 17 | 
            +
            from scipy.optimize import linear_sum_assignment
         | 
| 18 | 
            +
            from torch import nn
         | 
| 19 | 
            +
            from torch.cuda.amp import autocast
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from .point_features import point_sample    
         | 
| 22 | 
            +
            from ..language.loss import vl_similarity
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
                Compute the DICE loss, similar to generalized IOU for masks
         | 
| 27 | 
            +
                Args:
         | 
| 28 | 
            +
                    inputs: A float tensor of arbitrary shape.
         | 
| 29 | 
            +
                            The predictions for each example.
         | 
| 30 | 
            +
                    targets: A float tensor with the same shape as inputs. Stores the binary
         | 
| 31 | 
            +
                             classification label for each element in inputs
         | 
| 32 | 
            +
                            (0 for the negative class and 1 for the positive class).
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                inputs = inputs.sigmoid()
         | 
| 35 | 
            +
                inputs = inputs.flatten(1)
         | 
| 36 | 
            +
                numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
         | 
| 37 | 
            +
                denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
         | 
| 38 | 
            +
                loss = 1 - (numerator + 1) / (denominator + 1)
         | 
| 39 | 
            +
                return loss
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            batch_dice_loss_jit = torch.jit.script(
         | 
| 43 | 
            +
                batch_dice_loss
         | 
| 44 | 
            +
            )  # type: torch.jit.ScriptModule
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
         | 
| 48 | 
            +
                """
         | 
| 49 | 
            +
                Args:
         | 
| 50 | 
            +
                    inputs: A float tensor of arbitrary shape.
         | 
| 51 | 
            +
                            The predictions for each example.
         | 
| 52 | 
            +
                    targets: A float tensor with the same shape as inputs. Stores the binary
         | 
| 53 | 
            +
                             classification label for each element in inputs
         | 
| 54 | 
            +
                            (0 for the negative class and 1 for the positive class).
         | 
| 55 | 
            +
                Returns:
         | 
| 56 | 
            +
                    Loss tensor
         | 
| 57 | 
            +
                """
         | 
| 58 | 
            +
                hw = inputs.shape[1]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                pos = F.binary_cross_entropy_with_logits(
         | 
| 61 | 
            +
                    inputs, torch.ones_like(inputs), reduction="none"
         | 
| 62 | 
            +
                )
         | 
| 63 | 
            +
                neg = F.binary_cross_entropy_with_logits(
         | 
| 64 | 
            +
                    inputs, torch.zeros_like(inputs), reduction="none"
         | 
| 65 | 
            +
                )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
         | 
| 68 | 
            +
                    "nc,mc->nm", neg, (1 - targets)
         | 
| 69 | 
            +
                )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                return loss / hw
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            batch_sigmoid_ce_loss_jit = torch.jit.script(
         | 
| 75 | 
            +
                batch_sigmoid_ce_loss
         | 
| 76 | 
            +
            )  # type: torch.jit.ScriptModule
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            class HungarianMatcher(nn.Module):
         | 
| 80 | 
            +
                """This class computes an assignment between the targets and the predictions of the network
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                For efficiency reasons, the targets don't include the no_object. Because of this, in general,
         | 
| 83 | 
            +
                there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
         | 
| 84 | 
            +
                while the others are un-matched (and thus treated as non-objects).
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, spatial_cost = None):
         | 
| 88 | 
            +
                    """Creates the matcher
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    Params:
         | 
| 91 | 
            +
                        cost_class: This is the relative weight of the classification error in the matching cost
         | 
| 92 | 
            +
                        cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
         | 
| 93 | 
            +
                        cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
         | 
| 94 | 
            +
                    """
         | 
| 95 | 
            +
                    super().__init__()
         | 
| 96 | 
            +
                    self.cost_class = cost_class
         | 
| 97 | 
            +
                    self.cost_mask = cost_mask
         | 
| 98 | 
            +
                    self.cost_dice = cost_dice
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    self.num_points = num_points
         | 
| 101 | 
            +
                    self.spatial_cost_class = cost_class
         | 
| 102 | 
            +
                    self.spatial_cost_mask = cost_mask
         | 
| 103 | 
            +
                    self.spatial_cost_dice = cost_dice
         | 
| 104 | 
            +
                    assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                @torch.no_grad()
         | 
| 107 | 
            +
                def memory_efficient_forward(self, outputs, targets):
         | 
| 108 | 
            +
                    """More memory-friendly matching"""
         | 
| 109 | 
            +
                    bs, num_queries = outputs["pred_logits"].shape[:2]
         | 
| 110 | 
            +
                    
         | 
| 111 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 112 | 
            +
                        return None
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    indices = []
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    # Iterate through batch size
         | 
| 117 | 
            +
                    for b in range(bs):
         | 
| 118 | 
            +
                        out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
         | 
| 119 | 
            +
                        tgt_ids = targets[b]["labels"]
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 122 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 123 | 
            +
                        # The 1 is a constant that doesn't change the matching, so it can be omitted.
         | 
| 124 | 
            +
                        cost_class = -out_prob[:, tgt_ids]
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 127 | 
            +
                        # gt masks are already padded when preparing target
         | 
| 128 | 
            +
                        tgt_mask = targets[b]["masks"].to(out_mask)
         | 
| 129 | 
            +
                        
         | 
| 130 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 131 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 132 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 133 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 134 | 
            +
                        # get gt labels
         | 
| 135 | 
            +
                        tgt_mask = point_sample(
         | 
| 136 | 
            +
                            tgt_mask,
         | 
| 137 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 138 | 
            +
                            align_corners=False,
         | 
| 139 | 
            +
                        ).squeeze(1)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                        out_mask = point_sample(
         | 
| 142 | 
            +
                            out_mask,
         | 
| 143 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 144 | 
            +
                            align_corners=False,
         | 
| 145 | 
            +
                        ).squeeze(1)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                        with autocast(enabled=False):
         | 
| 148 | 
            +
                            out_mask = out_mask.float()
         | 
| 149 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 150 | 
            +
                            # Compute the sigmoid cross-entropy loss between masks
         | 
| 151 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                            # Compute the dice loss between masks
         | 
| 154 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
         | 
| 155 | 
            +
                        
         | 
| 156 | 
            +
                        # Final cost matrix
         | 
| 157 | 
            +
                        C = (
         | 
| 158 | 
            +
                            self.cost_mask * cost_mask
         | 
| 159 | 
            +
                            + self.cost_class * cost_class
         | 
| 160 | 
            +
                            + self.cost_dice * cost_dice
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 163 | 
            +
                        if C.isnan().any():
         | 
| 164 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 165 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 166 | 
            +
                            raise
         | 
| 167 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    return [
         | 
| 170 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 171 | 
            +
                        for i, j in indices
         | 
| 172 | 
            +
                    ]
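# Reading the result of memory_efficient_forward (values below are made up, not from the diff):
# [(tensor([5, 17]), tensor([1, 0]))] means that in image 0, query 5 is assigned to ground-truth
# object 1 and query 17 to ground-truth object 0; every other query stays unmatched and is
# supervised as "no object" by the criterion. A hypothetical helper to pair them up:
def matched_pairs(pred_masks, gt_masks_per_image, indices):
    # pred_masks: [batch, num_queries, H, W]; gt_masks_per_image: list of [num_gt_b, H, W]
    return [
        (pred_masks[b, src], gt_masks_per_image[b][tgt])  # rows are aligned pairwise
        for b, (src, tgt) in enumerate(indices)
    ]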
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                @torch.no_grad()
         | 
| 175 | 
            +
                def openimage_forward(self, outputs, targets, extra):
         | 
| 176 | 
            +
                    """More memory-friendly matching"""
         | 
| 177 | 
            +
                    bs, num_queries = outputs["pred_captions"].shape[:2]
         | 
| 178 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 179 | 
            +
                        return None
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    neg_class_emb = extra['neg_class_emb']
         | 
| 182 | 
            +
                    neg_hash = extra['neg_hash']
         | 
| 183 | 
            +
                    _, unique_indices = np.unique(neg_hash.cpu().numpy(), return_index=True)
         | 
| 184 | 
            +
                    neg_class_emb = neg_class_emb[unique_indices]
         | 
| 185 | 
            +
                    neg_hash = neg_hash[unique_indices]
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    indices = []
         | 
| 188 | 
            +
                    pred_logits = []
         | 
| 189 | 
            +
                    # Iterate through batch size
         | 
| 190 | 
            +
                    for b in range(bs):
         | 
| 191 | 
            +
                        _pos_class_emb = targets[b]['pos_class_emb']
         | 
| 192 | 
            +
                        _pos_hash = targets[b]['pos_hash']
         | 
| 193 | 
            +
                        _neg_overlap_pos = ~(neg_hash[..., None] == _pos_hash).any(-1)
         | 
| 194 | 
            +
                        _neg_class_emb = neg_class_emb[_neg_overlap_pos]
         | 
| 195 | 
            +
                        t_emb = torch.cat((_pos_class_emb, _neg_class_emb))
         | 
| 196 | 
            +
                        v_emb = outputs["pred_captions"][b]            
         | 
| 197 | 
            +
                        del _pos_class_emb
         | 
| 198 | 
            +
                        del _neg_class_emb
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                        t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 201 | 
            +
                        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)            
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                        out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
         | 
| 204 | 
            +
                        pred_logits += [out_prob]
         | 
| 205 | 
            +
                        out_prob = out_prob.softmax(-1)
         | 
| 206 | 
            +
                        tgt_ids = targets[b]["labels"]
         | 
| 207 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 208 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 209 | 
            +
                        # The 1 is a constant that doesn't change the matching, so it can be omitted.
         | 
| 210 | 
            +
                        cost_class = -out_prob[:, tgt_ids]
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                        out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 213 | 
            +
                        # gt masks are already padded when preparing target
         | 
| 214 | 
            +
                        tgt_mask = targets[b]["masks"].to(out_mask)
         | 
| 215 | 
            +
                        
         | 
| 216 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 217 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 218 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 219 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 220 | 
            +
                        # get gt labels
         | 
| 221 | 
            +
                        tgt_mask = point_sample(
         | 
| 222 | 
            +
                            tgt_mask,
         | 
| 223 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 224 | 
            +
                            align_corners=False,
         | 
| 225 | 
            +
                        ).squeeze(1)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                        out_mask = point_sample(
         | 
| 228 | 
            +
                            out_mask,
         | 
| 229 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 230 | 
            +
                            align_corners=False,
         | 
| 231 | 
            +
                        ).squeeze(1)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                        with autocast(enabled=False):
         | 
| 234 | 
            +
                            out_mask = out_mask.float()
         | 
| 235 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 236 | 
            +
                            # Compute the sigmoid cross-entropy loss between masks
         | 
| 237 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                            # Compute the dice loss between masks
         | 
| 240 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
         | 
| 241 | 
            +
                        
         | 
| 242 | 
            +
                        # Final cost matrix
         | 
| 243 | 
            +
                        C = (
         | 
| 244 | 
            +
                            self.cost_mask * cost_mask
         | 
| 245 | 
            +
                            + self.cost_class * cost_class
         | 
| 246 | 
            +
                            + self.cost_dice * cost_dice
         | 
| 247 | 
            +
                        )
         | 
| 248 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 249 | 
            +
                        if C.isnan().any():
         | 
| 250 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 251 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 252 | 
            +
                            raise
         | 
| 253 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    return [
         | 
| 256 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 257 | 
            +
                        for i, j in indices
         | 
| 258 | 
            +
                    ], pred_logits
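# In openimage_forward, negative class embeddings gathered from all ranks are de-duplicated by
# their hash before the similarity computation. A toy version of that step (tensors are made up):
import numpy as np
import torch

neg_hash = torch.tensor([11, 42, 11, 7])
neg_class_emb = torch.randn(4, 8)
_, unique_indices = np.unique(neg_hash.cpu().numpy(), return_index=True)
neg_class_emb = neg_class_emb[unique_indices]   # one embedding per unique hash
neg_hash = neg_hash[unique_indices]             # hashes 7, 11, 42 (sorted by value)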
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                @torch.no_grad()
         | 
| 261 | 
            +
                def grounding_forward(self, outputs, targets, extra):
         | 
| 262 | 
            +
                    """More memory-friendly matching"""
         | 
| 263 | 
            +
                    bs, num_queries = outputs["pred_gmasks"].shape[:2]
         | 
| 264 | 
            +
                    
         | 
| 265 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 266 | 
            +
                        return None
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    indices = []
         | 
| 269 | 
            +
                    # Iterate through batch size
         | 
| 270 | 
            +
                    for b in range(bs):
         | 
| 271 | 
            +
                        out_prob = outputs["pred_logits"][b]
         | 
| 272 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 273 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 274 | 
            +
                        # The 1 is a constant that doesn't change the matching, so it can be omitted.
         | 
| 275 | 
            +
                        cost_class = -out_prob.softmax(dim=0)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                        out_mask = outputs["pred_gmasks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 278 | 
            +
                        # gt masks are already padded when preparing target
         | 
| 279 | 
            +
                        tgt_mask = targets[b]["grounding_masks"].to(out_mask)
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 282 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 283 | 
            +
                        
         | 
| 284 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 285 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 286 | 
            +
                        # get gt labels
         | 
| 287 | 
            +
                        tgt_mask = point_sample(
         | 
| 288 | 
            +
                            tgt_mask,
         | 
| 289 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 290 | 
            +
                            align_corners=False,
         | 
| 291 | 
            +
                        ).squeeze(1)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                        out_mask = point_sample(
         | 
| 294 | 
            +
                            out_mask,
         | 
| 295 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 296 | 
            +
                            align_corners=False,
         | 
| 297 | 
            +
                        ).squeeze(1)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                        with autocast(enabled=False):
         | 
| 300 | 
            +
                            out_mask = out_mask.float()
         | 
| 301 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 302 | 
            +
                            # Compute the sigmoid cross-entropy loss between masks
         | 
| 303 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                            # Compute the dice loss between masks
         | 
| 306 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
         | 
| 307 | 
            +
                            
         | 
| 308 | 
            +
                        # Final cost matrix
         | 
| 309 | 
            +
                        C = (
         | 
| 310 | 
            +
                            self.cost_mask * cost_mask
         | 
| 311 | 
            +
                            + self.cost_class * cost_class
         | 
| 312 | 
            +
                            + self.cost_dice * cost_dice
         | 
| 313 | 
            +
                        )
         | 
| 314 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 315 | 
            +
                        if C.isnan().any():
         | 
| 316 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 317 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 318 | 
            +
                            raise
         | 
| 319 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    return [
         | 
| 322 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 323 | 
            +
                        for i, j in indices
         | 
| 324 | 
            +
                    ]
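# In grounding_forward, the classification-cost softmax runs over the query dimension (dim=0)
# rather than over classes: for each grounding target it asks which query responds most strongly.
# Toy illustration (shapes and values are made up):
import torch

logits = torch.randn(5, 2)              # 5 queries, 2 grounding targets
cost_class = -logits.softmax(dim=0)     # [5, 2]; each column sums to -1
# Entry [i, j] is cheapest (closest to -1) for the query i that dominates target j.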
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                @torch.no_grad()
         | 
| 327 | 
            +
                def spatial_forward(self, outputs, targets, extra):
         | 
| 328 | 
            +
                    """More memory-friendly matching"""
         | 
| 329 | 
            +
                    bs, num_queries = outputs["pred_smasks"].shape[:2]
         | 
| 330 | 
            +
                    
         | 
| 331 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 332 | 
            +
                        return None
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    indices = []
         | 
| 335 | 
            +
                    # Iterate through batch size
         | 
| 336 | 
            +
                    for b in range(bs):
         | 
| 337 | 
            +
                        out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 338 | 
            +
                        # gt masks are already padded when preparing target
         | 
| 339 | 
            +
                        tgt_mask = targets[b]["gt_spatial_masks"].to(out_mask)
         | 
| 340 | 
            +
                        nd,ns = outputs["pred_pos_logits"][b].shape
         | 
| 341 | 
            +
                        index_masking = 1-torch.eye(ns, device=out_mask.device, dtype=tgt_mask.dtype).repeat_interleave(nd//ns,dim=0)
         | 
| 342 | 
            +
                        neg_masking = torch.zeros((nd,ns), device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 343 | 
            +
                        neg_masking.masked_fill_(index_masking.bool(), -float('inf'))
         | 
| 344 | 
            +
                        pos_masking = torch.zeros((nd,ns), device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 345 | 
            +
                        pos_masking.masked_fill_(index_masking.bool(), float('inf'))
         | 
| 346 | 
            +
                        out_prob = (outputs["pred_pos_logits"][b]+neg_masking)[:,:len(tgt_mask)] # remove redundant predictions for padding
         | 
| 347 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 348 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 349 | 
            +
                        # The 1 is a constant that doesn't change the matching, so it can be omitted.
         | 
| 350 | 
            +
                        cost_class = -out_prob.softmax(dim=0)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 353 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 354 | 
            +
                        
         | 
| 355 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 356 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 357 | 
            +
                        # get gt labels
         | 
| 358 | 
            +
                        tgt_mask = point_sample(
         | 
| 359 | 
            +
                            tgt_mask,
         | 
| 360 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 361 | 
            +
                            align_corners=False,
         | 
| 362 | 
            +
                        ).squeeze(1)
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                        out_mask = point_sample(
         | 
| 365 | 
            +
                            out_mask,
         | 
| 366 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 367 | 
            +
                            align_corners=False,
         | 
| 368 | 
            +
                        ).squeeze(1)
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                        with autocast(enabled=False):
         | 
| 371 | 
            +
                            out_mask = out_mask.float()
         | 
| 372 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 373 | 
            +
                            # Compute the sigmoid cross-entropy loss between masks
         | 
| 374 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + pos_masking[:,:len(tgt_mask)]
         | 
| 375 | 
            +
                            # Compute the dice loss between masks
         | 
| 376 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + pos_masking[:,:len(tgt_mask)]
         | 
| 377 | 
            +
                        
         | 
| 378 | 
            +
                        # Final cost matrix
         | 
| 379 | 
            +
                        C = (
         | 
| 380 | 
            +
                            self.spatial_cost_mask * cost_mask 
         | 
| 381 | 
            +
                            + self.spatial_cost_class * cost_class 
         | 
| 382 | 
            +
                            + self.spatial_cost_dice * cost_dice
         | 
| 383 | 
            +
                        )
         | 
| 384 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 385 | 
            +
                        if C.isnan().any():
         | 
| 386 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 387 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 388 | 
            +
                            raise
         | 
| 389 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    return [
         | 
| 392 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 393 | 
            +
                        for i, j in indices
         | 
| 394 | 
            +
                    ]
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                @torch.no_grad()
         | 
| 397 | 
            +
                def spatial_forward_pn(self, outputs, targets, extra):
         | 
| 398 | 
            +
                    """More memory-friendly matching"""
         | 
| 399 | 
            +
                    bs, num_queries = outputs["pred_smasks"].shape[:2]
         | 
| 400 | 
            +
                    
         | 
| 401 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 402 | 
            +
                        return None
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    fp_mask = extra['false_positive_mask']
         | 
| 405 | 
            +
                    gt_mask = torch.stack([targets[b]["gt_spatial_masks"] for b in range(bs)])
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    indices = []
         | 
| 408 | 
            +
                    # Iterate through batch size
         | 
| 409 | 
            +
                    for b in range(bs):
         | 
| 410 | 
            +
                        out_prob = outputs["pred_neg_logits"][b]
         | 
| 411 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 412 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 413 | 
            +
                        # The 1 is a constant that doesn't change the matching, so it can be omitted.
         | 
| 414 | 
            +
                        cost_class = -out_prob.softmax(dim=0)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                        out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 417 | 
            +
                        tgt_mask = fp_mask[b].to(out_mask)
         | 
| 418 | 
            +
                        ign_mask = (gt_mask[b] | fp_mask[b]).to(out_mask)
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 421 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 422 | 
            +
                        ign_mask = ign_mask[:, None]
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 425 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                        # get gt labels
         | 
| 428 | 
            +
                        tgt_mask = point_sample(
         | 
| 429 | 
            +
                            tgt_mask,
         | 
| 430 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 431 | 
            +
                            align_corners=False,
         | 
| 432 | 
            +
                        ).squeeze(1)
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                        out_mask = point_sample(
         | 
| 435 | 
            +
                            out_mask,
         | 
| 436 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 437 | 
            +
                            align_corners=False,
         | 
| 438 | 
            +
                        ).squeeze(1)
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                        ign_mask = point_sample(
         | 
| 441 | 
            +
                            ign_mask,
         | 
| 442 | 
            +
                            point_coords.repeat(ign_mask.shape[0], 1, 1),
         | 
| 443 | 
            +
                            align_corners=False,
         | 
| 444 | 
            +
                        ).squeeze(1)
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                        with autocast(enabled=False):
         | 
| 447 | 
            +
                            out_mask = out_mask.float()
         | 
| 448 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 449 | 
            +
                            ign_mask = ign_mask.float()
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                            # Compute the sigmoid cross-entropy loss between masks
         | 
| 452 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                            # Compute the dice loss between masks
         | 
| 455 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)
         | 
| 456 | 
            +
                        
         | 
| 457 | 
            +
                        # Final cost matrix
         | 
| 458 | 
            +
                        C = (
         | 
| 459 | 
            +
                            self.spatial_cost_mask * cost_mask 
         | 
| 460 | 
            +
                            + self.spatial_cost_class * cost_class 
         | 
| 461 | 
            +
                            + self.spatial_cost_dice * cost_dice
         | 
| 462 | 
            +
                        )
         | 
| 463 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 464 | 
            +
                        if C.isnan().any():
         | 
| 465 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 466 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 467 | 
            +
                            raise
         | 
| 468 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    return [
         | 
| 471 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 472 | 
            +
                        for i, j in indices
         | 
| 473 | 
            +
                    ]
         | 
| 474 | 
            +
             | 
| 475 | 
            +
    @torch.no_grad()
    def caption_forward_womask(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, _ = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        t_emb = torch.cat([t['captions'] for t in targets])
        v_emb = outputs['unmatched_pred_captions']
        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])

        # Iterate through batch size
        for b in range(bs):
            v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
            num_queries = len(v_emb[b])
            out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
            tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it by 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, so it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            # Final cost matrix
            C = (self.cost_class * cost_class)
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  # temporary fix
                warnings.warn("NaN detected in cost matrix")
                raise FloatingPointError("NaN in cost matrix")
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

         | 
| 515 | 
            +
                def caption_forward_wmask(self, outputs, targets, extra):
         | 
| 516 | 
            +
                    """More memory-friendly matching"""
         | 
| 517 | 
            +
                    bs, _ = outputs["pred_logits"].shape[:2]
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                    if bs == 0 or len(targets) == 0:
         | 
| 520 | 
            +
                        return None
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    indices = []
         | 
| 523 | 
            +
                    t_emb = torch.cat([t['captions'] for t in targets])
         | 
| 524 | 
            +
                    v_emb = outputs['unmatched_pred_captions']
         | 
| 525 | 
            +
                    caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])
         | 
| 526 | 
            +
                    
         | 
| 527 | 
            +
                    # Iterate through batch size
         | 
| 528 | 
            +
                    for b in range(bs):
         | 
| 529 | 
            +
                        v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
         | 
| 530 | 
            +
                        num_queries = len(v_emb[b])
         | 
| 531 | 
            +
                        
         | 
| 532 | 
            +
                        out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
         | 
| 533 | 
            +
                        tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         | 
| 536 | 
            +
                        # but approximate it in 1 - proba[target class].
         | 
| 537 | 
            +
                        # The 1 is a constant that doesn't change the matching, it can be ommitted.
         | 
| 538 | 
            +
                        cost_class = -out_prob[:, tgt_ids]
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                        out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
         | 
| 541 | 
            +
                        # gt masks are already padded when preparing target
         | 
| 542 | 
            +
                        tgt_mask = targets[b]["masks"].to(out_mask)
         | 
| 543 | 
            +
                        
         | 
| 544 | 
            +
                        out_mask = out_mask[:, None]
         | 
| 545 | 
            +
                        tgt_mask = tgt_mask[:, None]
         | 
| 546 | 
            +
                        # all masks share the same set of points for efficient matching!
         | 
| 547 | 
            +
                        point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
         | 
| 548 | 
            +
                        # get gt labels
         | 
| 549 | 
            +
                        tgt_mask = point_sample(
         | 
| 550 | 
            +
                            tgt_mask,
         | 
| 551 | 
            +
                            point_coords.repeat(tgt_mask.shape[0], 1, 1),
         | 
| 552 | 
            +
                            align_corners=False,
         | 
| 553 | 
            +
                        ).squeeze(1)
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                        out_mask = point_sample(
         | 
| 556 | 
            +
                            out_mask,
         | 
| 557 | 
            +
                            point_coords.repeat(out_mask.shape[0], 1, 1),
         | 
| 558 | 
            +
                            align_corners=False,
         | 
| 559 | 
            +
                        ).squeeze(1)
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                        with autocast(enabled=False):
         | 
| 562 | 
            +
                            out_mask = out_mask.float()
         | 
| 563 | 
            +
                            tgt_mask = tgt_mask.float()
         | 
| 564 | 
            +
                            # Compute the focal loss between masks
         | 
| 565 | 
            +
                            cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                            # Compute the dice loss betwen masks
         | 
| 568 | 
            +
                            cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                        # Final cost matrix
         | 
| 571 | 
            +
                        C = (
         | 
| 572 | 
            +
                            self.cost_mask * cost_mask
         | 
| 573 | 
            +
                            + self.cost_class * cost_class
         | 
| 574 | 
            +
                            + self.cost_dice * cost_dice
         | 
| 575 | 
            +
                        )
         | 
| 576 | 
            +
                        C = C.reshape(num_queries, -1).cpu()
         | 
| 577 | 
            +
                        if C.isnan().any():
         | 
| 578 | 
            +
                            C[C.isnan()] = 1e6 ### temporary fix
         | 
| 579 | 
            +
                            warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         | 
| 580 | 
            +
                            raise 
         | 
| 581 | 
            +
                        indices.append(linear_sum_assignment(C))
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                    return [
         | 
| 584 | 
            +
                        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
         | 
| 585 | 
            +
                        for i, j in indices
         | 
| 586 | 
            +
                    ]
         | 
| 587 | 
            +
             | 
| 588 | 
            +
    @torch.no_grad()
    def forward(self, outputs, targets, mode='default', extra={}):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        if mode == 'default':
            return self.memory_efficient_forward(outputs, targets)
        elif mode == 'grounding':
            return self.grounding_forward(outputs, targets, extra)
        elif mode == 'spatial':
            return self.spatial_forward(outputs, targets, extra)
        elif mode == 'spatial_pn':
            return self.spatial_forward_pn(outputs, targets, extra)
        elif mode == 'caption_womask':
            return self.caption_forward_womask(outputs, targets, extra)
        elif mode == 'caption_wmask':
            return self.caption_forward_wmask(outputs, targets, extra)
        else:
            assert False, "Mode {} is not supported.".format(mode)

    def __repr__(self, _repr_indent=4):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)

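For orientation, here is a minimal usage sketch of the default matching mode. It is not taken from the file above; the HungarianMatcher name, constructor arguments, and tensor shapes are illustrative assumptions, so check the class __init__ in this file for the real signature.

import torch

# Dummy predictions for a batch of 2 images (shapes are illustrative only).
num_queries, num_classes, H, W = 100, 16, 64, 64
outputs = {
    "pred_logits": torch.randn(2, num_queries, num_classes),
    "pred_masks": torch.randn(2, num_queries, H, W),
}
# Per-image targets: class labels and binary masks.
targets = [
    {"labels": torch.tensor([3, 7]), "masks": (torch.rand(2, H, W) > 0.5).float()},
    {"labels": torch.tensor([1]), "masks": (torch.rand(1, H, W) > 0.5).float()},
]

# Assumed constructor arguments for illustration.
matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0, num_points=12544)
indices = matcher(outputs, targets, mode='default')
# indices[b] is a (pred_idx, tgt_idx) pair of index tensors for batch element b.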