Upload model
Browse files- README.md +199 -0
- config.json +41 -0
- configuration.py +40 -0
- model.safetensors +3 -0
- modeling.py +98 -0
- unet.py +243 -0
    	
        README.md
    ADDED
    
    | @@ -0,0 +1,199 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            library_name: transformers
         | 
| 3 | 
            +
            tags: []
         | 
| 4 | 
            +
            ---
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Model Card for Model ID
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            <!-- Provide a quick summary of what the model is/does. -->
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            ## Model Details
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            ### Model Description
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            <!-- Provide a longer summary of what this model is. -->
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            - **Developed by:** [More Information Needed]
         | 
| 21 | 
            +
            - **Funded by [optional]:** [More Information Needed]
         | 
| 22 | 
            +
            - **Shared by [optional]:** [More Information Needed]
         | 
| 23 | 
            +
            - **Model type:** [More Information Needed]
         | 
| 24 | 
            +
            - **Language(s) (NLP):** [More Information Needed]
         | 
| 25 | 
            +
            - **License:** [More Information Needed]
         | 
| 26 | 
            +
            - **Finetuned from model [optional]:** [More Information Needed]
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            ### Model Sources [optional]
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            <!-- Provide the basic links for the model. -->
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            - **Repository:** [More Information Needed]
         | 
| 33 | 
            +
            - **Paper [optional]:** [More Information Needed]
         | 
| 34 | 
            +
            - **Demo [optional]:** [More Information Needed]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            ## Uses
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            ### Direct Use
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            [More Information Needed]
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            ### Downstream Use [optional]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            [More Information Needed]
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            ### Out-of-Scope Use
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            [More Information Needed]
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            ## Bias, Risks, and Limitations
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            <!-- This section is meant to convey both technical and sociotechnical limitations. -->
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            [More Information Needed]
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            ### Recommendations
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            ## How to Get Started with the Model
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            Use the code below to get started with the model.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            [More Information Needed]
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            ## Training Details
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            ### Training Data
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            [More Information Needed]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            ### Training Procedure
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            #### Preprocessing [optional]
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            [More Information Needed]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            #### Training Hyperparameters
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            #### Speeds, Sizes, Times [optional]
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            [More Information Needed]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            ## Evaluation
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            <!-- This section describes the evaluation protocols and provides the results. -->
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            ### Testing Data, Factors & Metrics
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            #### Testing Data
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            <!-- This should link to a Dataset Card if possible. -->
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            [More Information Needed]
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            #### Factors
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            [More Information Needed]
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            #### Metrics
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            <!-- These are the evaluation metrics being used, ideally with a description of why. -->
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            [More Information Needed]
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            ### Results
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            [More Information Needed]
         | 
| 130 | 
            +
             | 
| 131 | 
            +
            #### Summary
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            ## Model Examination [optional]
         | 
| 136 | 
            +
             | 
| 137 | 
            +
            <!-- Relevant interpretability work for the model goes here -->
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            [More Information Needed]
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            ## Environmental Impact
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
         | 
| 144 | 
            +
             | 
| 145 | 
            +
            Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
         | 
| 146 | 
            +
             | 
| 147 | 
            +
            - **Hardware Type:** [More Information Needed]
         | 
| 148 | 
            +
            - **Hours used:** [More Information Needed]
         | 
| 149 | 
            +
            - **Cloud Provider:** [More Information Needed]
         | 
| 150 | 
            +
            - **Compute Region:** [More Information Needed]
         | 
| 151 | 
            +
            - **Carbon Emitted:** [More Information Needed]
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            ## Technical Specifications [optional]
         | 
| 154 | 
            +
             | 
| 155 | 
            +
            ### Model Architecture and Objective
         | 
| 156 | 
            +
             | 
| 157 | 
            +
            [More Information Needed]
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            ### Compute Infrastructure
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            [More Information Needed]
         | 
| 162 | 
            +
             | 
| 163 | 
            +
            #### Hardware
         | 
| 164 | 
            +
             | 
| 165 | 
            +
            [More Information Needed]
         | 
| 166 | 
            +
             | 
| 167 | 
            +
            #### Software
         | 
| 168 | 
            +
             | 
| 169 | 
            +
            [More Information Needed]
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            ## Citation [optional]
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
         | 
| 174 | 
            +
             | 
| 175 | 
            +
            **BibTeX:**
         | 
| 176 | 
            +
             | 
| 177 | 
            +
            [More Information Needed]
         | 
| 178 | 
            +
             | 
| 179 | 
            +
            **APA:**
         | 
| 180 | 
            +
             | 
| 181 | 
            +
            [More Information Needed]
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            ## Glossary [optional]
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
         | 
| 186 | 
            +
             | 
| 187 | 
            +
            [More Information Needed]
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            ## More Information [optional]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            [More Information Needed]
         | 
| 192 | 
            +
             | 
| 193 | 
            +
            ## Model Card Authors [optional]
         | 
| 194 | 
            +
             | 
| 195 | 
            +
            [More Information Needed]
         | 
| 196 | 
            +
             | 
| 197 | 
            +
            ## Model Card Contact
         | 
| 198 | 
            +
             | 
| 199 | 
            +
            [More Information Needed]
         | 
    	
        config.json
    ADDED
    
    | @@ -0,0 +1,41 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "architectures": [
         | 
| 3 | 
            +
                "CXRModel"
         | 
| 4 | 
            +
              ],
         | 
| 5 | 
            +
              "auto_map": {
         | 
| 6 | 
            +
                "AutoConfig": "configuration.CXRConfig",
         | 
| 7 | 
            +
                "AutoModel": "modeling.CXRModel"
         | 
| 8 | 
            +
              },
         | 
| 9 | 
            +
              "backbone": "tf_efficientnetv2_s",
         | 
| 10 | 
            +
              "cls_dropout": 0.1,
         | 
| 11 | 
            +
              "cls_num_classes": 5,
         | 
| 12 | 
            +
              "decoder_attention_type": null,
         | 
| 13 | 
            +
              "decoder_center_block": false,
         | 
| 14 | 
            +
              "decoder_channels": [
         | 
| 15 | 
            +
                256,
         | 
| 16 | 
            +
                128,
         | 
| 17 | 
            +
                64,
         | 
| 18 | 
            +
                32,
         | 
| 19 | 
            +
                16
         | 
| 20 | 
            +
              ],
         | 
| 21 | 
            +
              "decoder_n_blocks": 5,
         | 
| 22 | 
            +
              "decoder_norm_layer": "bn",
         | 
| 23 | 
            +
              "encoder_channels": [
         | 
| 24 | 
            +
                24,
         | 
| 25 | 
            +
                48,
         | 
| 26 | 
            +
                64,
         | 
| 27 | 
            +
                160,
         | 
| 28 | 
            +
                256
         | 
| 29 | 
            +
              ],
         | 
| 30 | 
            +
              "feature_dim": 256,
         | 
| 31 | 
            +
              "img_size": [
         | 
| 32 | 
            +
                320,
         | 
| 33 | 
            +
                320
         | 
| 34 | 
            +
              ],
         | 
| 35 | 
            +
              "in_chans": 1,
         | 
| 36 | 
            +
              "model_type": "cxr_basic",
         | 
| 37 | 
            +
              "seg_dropout": 0.1,
         | 
| 38 | 
            +
              "seg_num_classes": 4,
         | 
| 39 | 
            +
              "torch_dtype": "float32",
         | 
| 40 | 
            +
              "transformers_version": "4.47.0"
         | 
| 41 | 
            +
            }
         | 
    	
        configuration.py
    ADDED
    
    | @@ -0,0 +1,40 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from transformers import PretrainedConfig
         | 
| 2 | 
            +
            from typing import List, Optional, Tuple
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class CXRConfig(PretrainedConfig):
         | 
| 6 | 
            +
                model_type = "cxr_basic"
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                def __init__(
         | 
| 9 | 
            +
                    self,
         | 
| 10 | 
            +
                    backbone: str = "tf_efficientnetv2_s",
         | 
| 11 | 
            +
                    feature_dim: int = 256,
         | 
| 12 | 
            +
                    seg_dropout: float = 0.1,
         | 
| 13 | 
            +
                    cls_dropout: float = 0.1,
         | 
| 14 | 
            +
                    seg_num_classes: int = 4,
         | 
| 15 | 
            +
                    cls_num_classes: int = 5,
         | 
| 16 | 
            +
                    in_chans: int = 1,
         | 
| 17 | 
            +
                    img_size: Tuple[int, int] = (320, 320),  # height, width
         | 
| 18 | 
            +
                    decoder_n_blocks: int = 5,
         | 
| 19 | 
            +
                    decoder_channels: List[int] = [256, 128, 64, 32, 16],
         | 
| 20 | 
            +
                    encoder_channels: List[int] = [24, 48, 64, 160, 256],
         | 
| 21 | 
            +
                    decoder_center_block: bool = False,
         | 
| 22 | 
            +
                    decoder_norm_layer: str = "bn",
         | 
| 23 | 
            +
                    decoder_attention_type: Optional[str] = None,
         | 
| 24 | 
            +
                    **kwargs,
         | 
| 25 | 
            +
                ):
         | 
| 26 | 
            +
                    self.backbone = backbone
         | 
| 27 | 
            +
                    self.feature_dim = feature_dim
         | 
| 28 | 
            +
                    self.seg_dropout = seg_dropout
         | 
| 29 | 
            +
                    self.cls_dropout = cls_dropout
         | 
| 30 | 
            +
                    self.seg_num_classes = seg_num_classes
         | 
| 31 | 
            +
                    self.cls_num_classes = cls_num_classes
         | 
| 32 | 
            +
                    self.in_chans = in_chans
         | 
| 33 | 
            +
                    self.img_size = img_size
         | 
| 34 | 
            +
                    self.decoder_n_blocks = decoder_n_blocks
         | 
| 35 | 
            +
                    self.decoder_channels = decoder_channels
         | 
| 36 | 
            +
                    self.encoder_channels = encoder_channels
         | 
| 37 | 
            +
                    self.decoder_center_block = decoder_center_block
         | 
| 38 | 
            +
                    self.decoder_norm_layer = decoder_norm_layer
         | 
| 39 | 
            +
                    self.decoder_attention_type = decoder_attention_type
         | 
| 40 | 
            +
                    super().__init__(**kwargs)
         | 
    	
        model.safetensors
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:f9aa6275e28f1f13f4977a919e12d46a0cac4aa64a7c51b49afd63e9978087e4
         | 
| 3 | 
            +
            size 89078700
         | 
    	
        modeling.py
    ADDED
    
    | @@ -0,0 +1,98 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import albumentations as A
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from numpy.typing import NDArray
         | 
| 6 | 
            +
            from transformers import PreTrainedModel
         | 
| 7 | 
            +
            from timm import create_model
         | 
| 8 | 
            +
            from typing import Optional
         | 
| 9 | 
            +
            from .configuration import CXRConfig
         | 
| 10 | 
            +
            from .unet import UnetDecoder, SegmentationHead
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            _PYDICOM_AVAILABLE = False
         | 
| 13 | 
            +
            try:
         | 
| 14 | 
            +
                from pydicom import dcmread
         | 
| 15 | 
            +
                from pydicom.pixels import apply_voi_lut
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                _PYDICOM_AVAILABLE = True
         | 
| 18 | 
            +
            except ModuleNotFoundError:
         | 
| 19 | 
            +
                pass
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class CXRModel(PreTrainedModel):
         | 
| 23 | 
            +
                config_class = CXRConfig
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def __init__(self, config):
         | 
| 26 | 
            +
                    super().__init__(config)
         | 
| 27 | 
            +
                    self.encoder = create_model(
         | 
| 28 | 
            +
                        model_name=config.backbone,
         | 
| 29 | 
            +
                        features_only=True,
         | 
| 30 | 
            +
                        pretrained=False,
         | 
| 31 | 
            +
                        in_chans=config.in_chans,
         | 
| 32 | 
            +
                    )
         | 
| 33 | 
            +
                    self.decoder = UnetDecoder(
         | 
| 34 | 
            +
                        decoder_n_blocks=config.decoder_n_blocks,
         | 
| 35 | 
            +
                        decoder_channels=config.decoder_channels,
         | 
| 36 | 
            +
                        encoder_channels=config.encoder_channels,
         | 
| 37 | 
            +
                        decoder_center_block=config.decoder_center_block,
         | 
| 38 | 
            +
                        decoder_norm_layer=config.decoder_norm_layer,
         | 
| 39 | 
            +
                        decoder_attention_type=config.decoder_attention_type,
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    self.img_size = config.img_size
         | 
| 42 | 
            +
                    self.segmentation_head = SegmentationHead(
         | 
| 43 | 
            +
                        in_channels=config.decoder_channels[-1],
         | 
| 44 | 
            +
                        out_channels=config.seg_num_classes,
         | 
| 45 | 
            +
                        size=self.img_size,
         | 
| 46 | 
            +
                    )
         | 
| 47 | 
            +
                    self.pooling = nn.AdaptiveAvgPool2d(1)
         | 
| 48 | 
            +
                    self.dropout = nn.Dropout(p=config.cls_dropout)
         | 
| 49 | 
            +
                    self.classifier = nn.Linear(config.feature_dim, config.cls_num_classes)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def normalize(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 52 | 
            +
                    # [0, 255] -> [-1, 1]
         | 
| 53 | 
            +
                    mini, maxi = 0.0, 255.0
         | 
| 54 | 
            +
                    x = (x - mini) / (maxi - mini)
         | 
| 55 | 
            +
                    x = (x - 0.5) * 2.0
         | 
| 56 | 
            +
                    return x
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                @staticmethod
         | 
| 59 | 
            +
                def load_image_from_dicom(path: str) -> Optional[NDArray]:
         | 
| 60 | 
            +
                    if not _PYDICOM_AVAILABLE:
         | 
| 61 | 
            +
                        print("`pydicom` is not installed, returning None ...")
         | 
| 62 | 
            +
                        return None
         | 
| 63 | 
            +
                    dicom = dcmread(path)
         | 
| 64 | 
            +
                    arr = apply_voi_lut(dicom.pixel_array, dicom)
         | 
| 65 | 
            +
                    if dicom.PhotometricInterpretation == "MONOCHROME1":
         | 
| 66 | 
            +
                        # invert image if needed
         | 
| 67 | 
            +
                        arr = arr.max() - arr
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    arr = arr - arr.min()
         | 
| 70 | 
            +
                    arr = arr / arr.max()
         | 
| 71 | 
            +
                    arr = (arr * 255).astype("uint8")
         | 
| 72 | 
            +
                    return arr
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def preprocess(self, x: NDArray) -> NDArray:
         | 
| 75 | 
            +
                    x = A.Resize(self.img_size[0], self.img_size[1], p=1)(image=x)["image"]
         | 
| 76 | 
            +
                    return x
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
         | 
| 79 | 
            +
                    x = self.normalize(x)
         | 
| 80 | 
            +
                    features = self.encoder(x)
         | 
| 81 | 
            +
                    decoder_output = self.decoder(features)
         | 
| 82 | 
            +
                    logits = self.segmentation_head(decoder_output[-1])
         | 
| 83 | 
            +
                    b, n = features[-1].shape[:2]
         | 
| 84 | 
            +
                    features = self.pooling(features[-1]).reshape(b, n)
         | 
| 85 | 
            +
                    features = self.dropout(features)
         | 
| 86 | 
            +
                    cls_logits = self.classifier(features)
         | 
| 87 | 
            +
                    out = {
         | 
| 88 | 
            +
                        "mask": logits,
         | 
| 89 | 
            +
                        "age": cls_logits[:, 0].unsqueeze(1),
         | 
| 90 | 
            +
                        "view": cls_logits[:, 1:4],
         | 
| 91 | 
            +
                        "female": cls_logits[:, 4].unsqueeze(1),
         | 
| 92 | 
            +
                    }
         | 
| 93 | 
            +
                    if return_logits:
         | 
| 94 | 
            +
                        return out
         | 
| 95 | 
            +
                    out["mask"] = out["mask"].softmax(1)
         | 
| 96 | 
            +
                    out["view"] = out["view"].softmax(1)
         | 
| 97 | 
            +
                    out["female"] = out["female"].sigmoid()
         | 
| 98 | 
            +
                    return out
         | 
    	
        unet.py
    ADDED
    
    | @@ -0,0 +1,243 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from functools import partial
         | 
| 6 | 
            +
            from typing import List, Optional
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class Conv2dAct(nn.Sequential):
         | 
| 10 | 
            +
                def __init__(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    in_channels: int,
         | 
| 13 | 
            +
                    out_channels: int,
         | 
| 14 | 
            +
                    kernel_size: int,
         | 
| 15 | 
            +
                    padding: int = 0,
         | 
| 16 | 
            +
                    stride: int = 1,
         | 
| 17 | 
            +
                    norm_layer: str = "bn",
         | 
| 18 | 
            +
                    num_groups: int = 32,  # for GroupNorm,
         | 
| 19 | 
            +
                    activation: str = "ReLU",
         | 
| 20 | 
            +
                    inplace: bool = True,  # for activation
         | 
| 21 | 
            +
                ):
         | 
| 22 | 
            +
                    if norm_layer == "bn":
         | 
| 23 | 
            +
                        NormLayer = nn.BatchNorm2d
         | 
| 24 | 
            +
                    elif norm_layer == "gn":
         | 
| 25 | 
            +
                        NormLayer = partial(nn.GroupNorm, num_groups=num_groups)
         | 
| 26 | 
            +
                    else:
         | 
| 27 | 
            +
                        raise Exception(
         | 
| 28 | 
            +
                            f"`norm_layer` must be one of [`bn`, `gn`], got `{norm_layer}`"
         | 
| 29 | 
            +
                        )
         | 
| 30 | 
            +
                    super().__init__()
         | 
| 31 | 
            +
                    self.conv = nn.Conv2d(
         | 
| 32 | 
            +
                        in_channels,
         | 
| 33 | 
            +
                        out_channels,
         | 
| 34 | 
            +
                        kernel_size=kernel_size,
         | 
| 35 | 
            +
                        stride=stride,
         | 
| 36 | 
            +
                        padding=padding,
         | 
| 37 | 
            +
                        bias=False,
         | 
| 38 | 
            +
                    )
         | 
| 39 | 
            +
                    self.norm = NormLayer(out_channels)
         | 
| 40 | 
            +
                    self.act = getattr(nn, activation)(inplace=inplace)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 43 | 
            +
                    return self.act(self.norm(self.conv(x)))
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            class SCSEModule(nn.Module):
         | 
| 47 | 
            +
                def __init__(
         | 
| 48 | 
            +
                    self,
         | 
| 49 | 
            +
                    in_channels: int,
         | 
| 50 | 
            +
                    reduction: int = 16,
         | 
| 51 | 
            +
                    activation: str = "ReLU",
         | 
| 52 | 
            +
                    inplace: bool = False,
         | 
| 53 | 
            +
                ):
         | 
| 54 | 
            +
                    super().__init__()
         | 
| 55 | 
            +
                    self.cSE = nn.Sequential(
         | 
| 56 | 
            +
                        nn.AdaptiveAvgPool2d(1),
         | 
| 57 | 
            +
                        nn.Conv2d(in_channels, in_channels // reduction, 1),
         | 
| 58 | 
            +
                        getattr(nn, activation)(inplace=inplace),
         | 
| 59 | 
            +
                        nn.Conv2d(in_channels // reduction, in_channels, 1),
         | 
| 60 | 
            +
                    )
         | 
| 61 | 
            +
                    self.sSE = nn.Conv2d(in_channels, 1, 1)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 64 | 
            +
                    return x * self.cSE(x).sigmoid() + x * self.sSE(x).sigmoid()
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class Attention(nn.Module):
         | 
| 68 | 
            +
                def __init__(self, name: str, **params):
         | 
| 69 | 
            +
                    super().__init__()
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    if name is None:
         | 
| 72 | 
            +
                        self.attention = nn.Identity(**params)
         | 
| 73 | 
            +
                    elif name == "scse":
         | 
| 74 | 
            +
                        self.attention = SCSEModule(**params)
         | 
| 75 | 
            +
                    else:
         | 
| 76 | 
            +
                        raise ValueError("Attention {} is not implemented".format(name))
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 79 | 
            +
                    return self.attention(x)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            class DecoderBlock(nn.Module):
         | 
| 83 | 
            +
                def __init__(
         | 
| 84 | 
            +
                    self,
         | 
| 85 | 
            +
                    in_channels: int,
         | 
| 86 | 
            +
                    skip_channels: int,
         | 
| 87 | 
            +
                    out_channels: int,
         | 
| 88 | 
            +
                    norm_layer: str = "bn",
         | 
| 89 | 
            +
                    activation: str = "ReLU",
         | 
| 90 | 
            +
                    attention_type: Optional[str] = None,
         | 
| 91 | 
            +
                ):
         | 
| 92 | 
            +
                    super().__init__()
         | 
| 93 | 
            +
                    self.conv1 = Conv2dAct(
         | 
| 94 | 
            +
                        in_channels + skip_channels,
         | 
| 95 | 
            +
                        out_channels,
         | 
| 96 | 
            +
                        kernel_size=3,
         | 
| 97 | 
            +
                        padding=1,
         | 
| 98 | 
            +
                        norm_layer=norm_layer,
         | 
| 99 | 
            +
                        activation=activation,
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
                    self.attention1 = Attention(
         | 
| 102 | 
            +
                        attention_type, in_channels=in_channels + skip_channels
         | 
| 103 | 
            +
                    )
         | 
| 104 | 
            +
                    self.conv2 = Conv2dAct(
         | 
| 105 | 
            +
                        out_channels,
         | 
| 106 | 
            +
                        out_channels,
         | 
| 107 | 
            +
                        kernel_size=3,
         | 
| 108 | 
            +
                        padding=1,
         | 
| 109 | 
            +
                        norm_layer=norm_layer,
         | 
| 110 | 
            +
                        activation=activation,
         | 
| 111 | 
            +
                    )
         | 
| 112 | 
            +
                    self.attention2 = Attention(attention_type, in_channels=out_channels)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def forward(
         | 
| 115 | 
            +
                    self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
         | 
| 116 | 
            +
                ) -> torch.Tensor:
         | 
| 117 | 
            +
                    if skip is not None:
         | 
| 118 | 
            +
                        h, w = skip.shape[2:]
         | 
| 119 | 
            +
                        x = F.interpolate(x, size=(h, w), mode="nearest")
         | 
| 120 | 
            +
                        x = torch.cat([x, skip], dim=1)
         | 
| 121 | 
            +
                        x = self.attention1(x)
         | 
| 122 | 
            +
                    else:
         | 
| 123 | 
            +
                        x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
         | 
| 124 | 
            +
                    x = self.conv1(x)
         | 
| 125 | 
            +
                    x = self.conv2(x)
         | 
| 126 | 
            +
                    x = self.attention2(x)
         | 
| 127 | 
            +
                    return x
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            class CenterBlock(nn.Sequential):
         | 
| 131 | 
            +
                def __init__(
         | 
| 132 | 
            +
                    self,
         | 
| 133 | 
            +
                    in_channels: int,
         | 
| 134 | 
            +
                    out_channels: int,
         | 
| 135 | 
            +
                    norm_layer: str = "bn",
         | 
| 136 | 
            +
                    activation: str = "ReLU",
         | 
| 137 | 
            +
                ):
         | 
| 138 | 
            +
                    conv1 = Conv2dAct(
         | 
| 139 | 
            +
                        in_channels,
         | 
| 140 | 
            +
                        out_channels,
         | 
| 141 | 
            +
                        kernel_size=3,
         | 
| 142 | 
            +
                        padding=1,
         | 
| 143 | 
            +
                        norm_layer=norm_layer,
         | 
| 144 | 
            +
                        activation=activation,
         | 
| 145 | 
            +
                    )
         | 
| 146 | 
            +
                    conv2 = Conv2dAct(
         | 
| 147 | 
            +
                        out_channels,
         | 
| 148 | 
            +
                        out_channels,
         | 
| 149 | 
            +
                        kernel_size=3,
         | 
| 150 | 
            +
                        padding=1,
         | 
| 151 | 
            +
                        norm_layer=norm_layer,
         | 
| 152 | 
            +
                        activation=activation,
         | 
| 153 | 
            +
                    )
         | 
| 154 | 
            +
                    super().__init__(conv1, conv2)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            class UnetDecoder(nn.Module):
         | 
| 158 | 
            +
                def __init__(
         | 
| 159 | 
            +
                    self,
         | 
| 160 | 
            +
                    decoder_n_blocks: int,
         | 
| 161 | 
            +
                    decoder_channels: List[int],
         | 
| 162 | 
            +
                    encoder_channels: List[int],
         | 
| 163 | 
            +
                    decoder_center_block: bool = False,
         | 
| 164 | 
            +
                    decoder_norm_layer: str = "bn",
         | 
| 165 | 
            +
                    decoder_attention_type: Optional[str] = None,
         | 
| 166 | 
            +
                ):
         | 
| 167 | 
            +
                    super().__init__()
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    self.decoder_n_blocks = decoder_n_blocks
         | 
| 170 | 
            +
                    self.decoder_channels = decoder_channels
         | 
| 171 | 
            +
                    self.encoder_channels = encoder_channels
         | 
| 172 | 
            +
                    self.decoder_center_block = decoder_center_block
         | 
| 173 | 
            +
                    self.decoder_norm_layer = decoder_norm_layer
         | 
| 174 | 
            +
                    self.decoder_attention_type = decoder_attention_type
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    if self.decoder_n_blocks != len(self.decoder_channels):
         | 
| 177 | 
            +
                        raise ValueError(
         | 
| 178 | 
            +
                            "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
         | 
| 179 | 
            +
                                self.decoder_n_blocks, len(self.decoder_channels)
         | 
| 180 | 
            +
                            )
         | 
| 181 | 
            +
                        )
         | 
| 182 | 
            +
                    # reverse channels to start from head of encoder
         | 
| 183 | 
            +
                    encoder_channels = encoder_channels[::-1]
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    # computing blocks input and output channels
         | 
| 186 | 
            +
                    head_channels = encoder_channels[0]
         | 
| 187 | 
            +
                    in_channels = [head_channels] + list(self.decoder_channels[:-1])
         | 
| 188 | 
            +
                    skip_channels = list(encoder_channels[1:]) + [0]
         | 
| 189 | 
            +
                    out_channels = self.decoder_channels
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    if self.decoder_center_block:
         | 
| 192 | 
            +
                        self.center = CenterBlock(
         | 
| 193 | 
            +
                            head_channels, head_channels, norm_layer=self.decoder_norm_layer
         | 
| 194 | 
            +
                        )
         | 
| 195 | 
            +
                    else:
         | 
| 196 | 
            +
                        self.center = nn.Identity()
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    # combine decoder keyword arguments
         | 
| 199 | 
            +
                    kwargs = dict(
         | 
| 200 | 
            +
                        norm_layer=self.decoder_norm_layer,
         | 
| 201 | 
            +
                        attention_type=self.decoder_attention_type,
         | 
| 202 | 
            +
                    )
         | 
| 203 | 
            +
                    blocks = [
         | 
| 204 | 
            +
                        DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
         | 
| 205 | 
            +
                        for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
         | 
| 206 | 
            +
                    ]
         | 
| 207 | 
            +
                    self.blocks = nn.ModuleList(blocks)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         | 
| 210 | 
            +
                    features = features[::-1]  # reverse channels to start from head of encoder
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                    head = features[0]
         | 
| 213 | 
            +
                    skips = features[1:]
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    output = [self.center(head)]
         | 
| 216 | 
            +
                    for i, decoder_block in enumerate(self.blocks):
         | 
| 217 | 
            +
                        skip = skips[i] if i < len(skips) else None
         | 
| 218 | 
            +
                        output.append(decoder_block(output[-1], skip))
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    return output
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
| 223 | 
            +
            class SegmentationHead(nn.Module):
         | 
| 224 | 
            +
                def __init__(
         | 
| 225 | 
            +
                    self,
         | 
| 226 | 
            +
                    in_channels: int,
         | 
| 227 | 
            +
                    out_channels: int,
         | 
| 228 | 
            +
                    size: int,
         | 
| 229 | 
            +
                    kernel_size: int = 3,
         | 
| 230 | 
            +
                    dropout: float = 0.0,
         | 
| 231 | 
            +
                ):
         | 
| 232 | 
            +
                    super().__init__()
         | 
| 233 | 
            +
                    self.drop = nn.Dropout2d(p=dropout)
         | 
| 234 | 
            +
                    self.conv = nn.Conv2d(
         | 
| 235 | 
            +
                        in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
         | 
| 236 | 
            +
                    )
         | 
| 237 | 
            +
                    if isinstance(size, (tuple, list)):
         | 
| 238 | 
            +
                        self.up = nn.Upsample(size=size, mode="bilinear")
         | 
| 239 | 
            +
                    else:
         | 
| 240 | 
            +
                        self.up = nn.Identity()
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 243 | 
            +
                    return self.up(self.conv(self.drop(x)))
         |