[Init] upload model
- .gitattributes +0 -34
- config.json +191 -0
- config.py +240 -0
- demo.py +143 -0
- flash_attention_class.py +74 -0
- internvideo2.py +779 -0
- internvideo2_clip_vision.py +553 -0
- mobile_clip.py +264 -0
- mobile_clip_transformer.py +449 -0
- model.safetensors +3 -0
- modeling_internvideo2encoder.py +152 -0
- pos_embed.py +299 -0
- test.ipynb +424 -0
    	
.gitattributes CHANGED
    
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
    	
config.json ADDED
    
@@ -0,0 +1,191 @@
{
  "architectures": [
    "InternVideo2_CLIP_small"
  ],
  "auto_map": {
    "AutoConfig": "config.InternVideo2Config",
    "AutoModel": "modeling_internvideo2encoder.InternVideo2_CLIP_small"
  },
  "auto_resume": false,
  "batch_size": 64,
  "batch_size_test": 4,
  "best_key": [
    "msrvtt_1k_test_match",
    "t2v_r1"
  ],
  "compile_model": false,
  "criterion": {
    "clip_loss_ratio": [
      1.0,
      1.0
    ],
    "distill_final_features": true,
    "loss_weight": {
      "mlm": 1.0,
      "mvm": 0.0,
      "uta": 0.0,
      "vtc": 1.0,
      "vtm": 1.0
    },
    "mlm_masking_prob": 0.5,
    "vtm_hard_neg": true
  },
  "debug": false,
  "deep_fusion": false,
  "deepspeed": {
    "enable": true,
    "stage": 1
  },
  "delete_ds_optim_states": true,
  "device": "cuda",
  "dist_url": "env://",
  "evaluate": false,
  "evaluation": {
    "eval_frame_ensemble": "concat",
    "eval_offload": true,
    "eval_x_only": false,
    "k_test": 128
  },
  "gradient_checkpointing": true,
  "inputs": {
    "batch_size": {
      "image": 64,
      "video": 64
    },
    "batch_size_test": {
      "image": 4,
      "video": 4
    },
    "image_res": 224,
    "max_txt_l": {
      "image": 32,
      "video": 32
    },
    "video_input": {
      "num_frames": 8,
      "num_frames_test": 8,
      "random_aug": false,
      "sample_type": "middle",
      "sample_type_test": "middle"
    }
  },
  "jump_evaluate": false,
  "log_freq": 100,
  "max_txt_l": 32,
  "mode": "pt",
  "model": {
    "embed_dim": 1024,
    "find_unused_parameters": false,
    "freeze_text": true,
    "freeze_vision": true,
    "load_vision_ckpt_from_internvideo2_stage2": false,
    "model_cls": "InternVideo2_CLIP_small",
    "multimodal": {
      "enable": true
    },
    "open_text_projection": false,
    "open_vision_clip_projector": true,
    "temp": 0.01,
    "temp_min": 0.01,
    "text_encoder": {
      "embed_dim": 512,
      "image_cfg": {
        "image_size": 224,
        "model_name": "vit_b16"
      },
      "text_cfg": {
        "causal_masking": true,
        "context_length": 77,
        "dim": 512,
        "ffn_multiplier_per_layer": 4.0,
        "model_name": "base",
        "n_heads_per_layer": 8,
        "n_transformer_layers": 12,
        "norm_layer": "layer_norm_fp32",
        "vocab_size": 49408
      }
    },
    "vision_encoder": {
      "align_dim": 512,
      "attn_pool_num_heads": 16,
      "checkpoint_num": 0,
      "clip_embed_dim": 768,
      "depth": 24,
      "drop_cls_token": false,
      "drop_path_rate": 0.0,
      "embed_dim": 1024,
      "fused_mlp_heuristic": 1,
      "head_drop_path_rate": 0.0,
      "img_size": 224,
      "in_chans": 3,
      "init_values": 0.1,
      "layerscale_no_force_fp32": true,
      "mlp_ratio": 4,
      "name": "internvideo2_1B",
      "num_frames": 8,
      "num_heads": 16,
      "patch_size": 14,
      "qk_normalization": true,
      "qkv_bias": false,
      "sep_pos_embed": false,
      "tubelet_size": 1,
      "use_checkpoint": false,
      "use_flash_attn": false,
      "use_fused_mlp": false,
      "use_fused_rmsnorm": false
    }
  },
  "model_type": "internvideo2",
  "num_frames": 8,
  "num_frames_test": 8,
  "num_workers": 6,
  "optimizer": {
    "different_lr": {
      "enable": false,
      "lr": 0.001,
      "module_names": []
    },
    "lr": 5e-05,
    "max_grad_norm": 3.0,
    "opt": "adamW",
    "opt_betas": [
      0.9,
      0.98
    ],
    "weight_decay": 0.05
  },
  "output_dir": null,
  "pretrained_path": "",
  "resume": false,
  "save_ckpt_iter": null,
  "save_latest": true,
  "scheduler": {
    "epochs": 10,
    "min_lr_multi": 0.01,
    "sched": "cosine",
    "warmup_epochs": 1
  },
  "seed": 42,
  "test_file": {
    "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]",
    "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]"
  },
  "test_types": [
    "msrvtt_1k_test",
    "didemo_ret_test"
  ],
  "text_enc": "bert_large",
  "tokenizer": null,
  "torch_dtype": "float16",
  "train_file": "available_corpus[\"pretrain_example_data_1B\"]",
  "transformers_version": "4.51.3",
  "use_bf16": true,
  "use_flash_sdp": false,
  "use_half_precision": false,
  "use_mem_efficient_sdp": false,
  "wandb": {
    "enable": false,
    "entity": "opengvlab",
    "project": "InternVideo2-Stage2"
  }
}
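
The `auto_map` block above is what lets the Transformers Auto classes resolve to the custom code shipped in this repo (`config.InternVideo2Config` and `modeling_internvideo2encoder.InternVideo2_CLIP_small`), which is why loading requires `trust_remote_code=True`. A minimal loading sketch, using a placeholder path in place of the hard-coded local directory that demo.py uses:

from transformers import AutoConfig, AutoModel

# Placeholder: point this at the downloaded repo directory (demo.py uses a local path).
repo_path = "path/to/this/repo"

# trust_remote_code=True is needed because auto_map points at config.py and
# modeling_internvideo2encoder.py shipped alongside config.json.
config = AutoConfig.from_pretrained(repo_path, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_path, trust_remote_code=True).to(config.device)

print(config.model_type)   # "internvideo2"
print(config.num_frames)   # 8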
    	
config.py ADDED
    
@@ -0,0 +1,240 @@
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig

class EasyDict(dict):
    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)

class InternVideo2Config(PretrainedConfig):
    model_type = "internvideo2"

    def __init__(self,
                 tokenizer=None,
                 train_file=None,
                 test_file=None,
                 test_types=None,
                 num_workers=6,
                 best_key=None,
                 num_frames=8,
                 num_frames_test=8,
                 batch_size=64,
                 batch_size_test=4,
                 max_txt_l=32,
                 inputs=None,
                 text_enc="bert_large",
                 model=None,
                 criterion=None,
                 optimizer=None,
                 scheduler=None,
                 evaluate=False,
                 deep_fusion=False,
                 evaluation=None,
                 use_half_precision=False,
                 use_bf16=True,
                 gradient_checkpointing=True,
                 use_flash_sdp=False,
                 use_mem_efficient_sdp=False,
                 compile_model=False,
                 wandb=None,
                 dist_url="env://",
                 device="cuda",
                 mode="pt",
                 output_dir=None,
                 resume=False,
                 debug=False,
                 log_freq=100,
                 seed=42,
                 save_latest=True,
                 auto_resume=False,
                 jump_evaluate=False,
                 pretrained_path="",
                 save_ckpt_iter=None,
                 delete_ds_optim_states=True,
                 deepspeed=None,
                 **kwargs):
        super().__init__(**kwargs)

        self.tokenizer = tokenizer

        # Data configuration
        self.train_file = train_file or "available_corpus[\"pretrain_example_data_1B\"]"
        self.test_file = EasyDict(test_file or {
            "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]",
            "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]"
        })
        self.test_types = test_types or ["msrvtt_1k_test", "didemo_ret_test"]
        self.num_workers = num_workers
        self.best_key = best_key or ["msrvtt_1k_test_match", "t2v_r1"]

        # Input configuration
        self.num_frames = num_frames
        self.num_frames_test = num_frames_test
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.max_txt_l = max_txt_l
        self.inputs = EasyDict(inputs or {
            "image_res": 224,
            "video_input": EasyDict({
                "num_frames": num_frames,
                "sample_type": "rand",
                "num_frames_test": num_frames_test,
                "sample_type_test": "middle",
                "random_aug": False
            }),
            "max_txt_l": EasyDict({"image": max_txt_l, "video": max_txt_l}),
            "batch_size": EasyDict({"image": batch_size, "video": batch_size}),
            "batch_size_test": EasyDict({"image": batch_size_test, "video": batch_size_test})
        })

        # Model configuration
        self.text_enc = text_enc
        self.model = EasyDict(model or {
            "model_cls": "InternVideo2_Stage2",
            "vision_encoder": EasyDict({
                "name": "pretrain_internvideo2_1b_patch14_224",
                "img_size": 224,
                "num_frames": num_frames,
                "tubelet_size": 1,
                "patch_size": 14,
                "d_model": 1408,
                "clip_embed_dim": 768,
                "clip_teacher_embed_dim": 3200,
                "clip_teacher_final_dim": 768,
                "clip_norm_type": "l2",
                "clip_return_layer": 6,
                "clip_student_return_interval": 1,
                "pretrained": None,
                "use_checkpoint": False,
                "checkpoint_num": 40,
                "use_flash_attn": True,
                "use_fused_rmsnorm": True,
                "use_fused_mlp": True,
                "clip_teacher": None,
                "clip_input_resolution": 224,
                "clip_teacher_return_interval": 1,
                "video_mask_type": "random",
                "video_mask_ratio": 0.8,
                "image_mask_type": "random",
                "image_mask_ratio": 0.5,
                "sep_image_video_pos_embed": True,
                "keep_temporal": False,
                "only_mask": True
            }),
            "text_encoder": text_enc,
            "multimodal": EasyDict({"enable": True}),
            "embed_dim": 512,
            "temp": 0.07,
            "find_unused_parameters": False
        })

        # Criterion configuration
        self.criterion = EasyDict(criterion or {
            "loss_weight": EasyDict({
                "vtc": 1.0,
                "mlm": 1.0,
                "vtm": 1.0,
                "mvm": 0.0,
                "uta": 0.0
            }),
            "vtm_hard_neg": True,
            "mlm_masking_prob": 0.5,
            "distill_final_features": True,
            "clip_loss_ratio": [1.0, 1.0]
        })

        # Optimizer configuration
        self.optimizer = EasyDict(optimizer or {
            "opt": "adamW",
            "lr": 5e-5,
            "opt_betas": [0.9, 0.98],
            "weight_decay": 0.05,
            "max_grad_norm": 3.0,
            "different_lr": EasyDict({"enable": False, "module_names": [], "lr": 1e-3})
        })

        # Scheduler configuration
        self.scheduler = EasyDict(scheduler or {
            "sched": "cosine",
            "epochs": 10,
            "min_lr_multi": 0.01,
            "warmup_epochs": 1
        })

        # Evaluation configuration
        self.evaluate = evaluate
        self.deep_fusion = deep_fusion
        self.evaluation = EasyDict(evaluation or {
            "eval_frame_ensemble": "concat",
            "eval_x_only": False,
            "k_test": 128,
            "eval_offload": True
        })

        # Miscellaneous
        self.use_half_precision = use_half_precision
        self.use_bf16 = use_bf16
        self.gradient_checkpointing = gradient_checkpointing
        self.use_flash_sdp = use_flash_sdp
        self.use_mem_efficient_sdp = use_mem_efficient_sdp
        self.compile_model = compile_model

        self.wandb = EasyDict(wandb or {
            "enable": False,
            "entity": "opengvlab",
            "project": "InternVideo2-Stage2"
        })

        self.dist_url = dist_url
        self.device = device
        self.mode = mode
        self.output_dir = output_dir
        self.resume = resume
        self.debug = debug
        self.log_freq = log_freq
        self.seed = seed

        self.save_latest = save_latest
        self.auto_resume = auto_resume
        self.jump_evaluate = jump_evaluate
        self.pretrained_path = pretrained_path
        self.save_ckpt_iter = save_ckpt_iter
        self.delete_ds_optim_states = delete_ds_optim_states

        self.deepspeed = EasyDict(deepspeed or {
            "enable": True,
            "stage": 1
        })
    def set_num_frames(self, num_frames):
        # print('Here ', num_frames)
        self.num_frames = num_frames
        self.inputs.video_input.num_frames = num_frames
        self.model.vision_encoder.num_frames = num_frames
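
EasyDict keeps attribute access and item access in sync, including for nested dicts, which is what makes dotted paths such as `config.inputs.video_input.num_frames` work, and `set_num_frames` pushes a single value into both the input pipeline and the vision-encoder sub-config. A small illustrative sketch of that behaviour, assuming config.py is importable from the repo root:

from config import EasyDict, InternVideo2Config

# Attribute and item access refer to the same underlying value.
opt = EasyDict({"opt": "adamW", "lr": 5e-5})
assert opt.lr == opt["lr"] == 5e-5

cfg = InternVideo2Config()
assert cfg.inputs.video_input.num_frames == 8   # nested dicts are wrapped recursively

# Propagate a new frame count into the nested configs in one call.
cfg.set_num_frames(16)
assert cfg.inputs.video_input.num_frames == 16
assert cfg.model.vision_encoder.num_frames == 16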
    	
demo.py ADDED
    
@@ -0,0 +1,143 @@
import os
import random
import io
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import numpy as np
import math
import torch.nn.functional as F
decord.bridge.set_bridge("torch")

from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    # Ensure start_frame is less than end_frame
    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    # Calculate the length of the clip in frames
    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, clip_length)
        # split the clip into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except:
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError
    return frame_indices

def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
    ):
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)

    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration

def get_text_feature(model, texts):
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features

def get_similarity(video_feature, text_feature):
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix

def get_top_videos(model, text_features, video_features, video_paths, texts):
    # text_features = get_text_feature(texts)

    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # print(text_features.shape, video_features.shape)
    sim_matrix = text_features @ video_features.T
    # print(sim_matrix.shape)

    top_k = 5
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            print("top", j+1, ":", video_paths[sim_matrix_top_k[i][j]], "~prob:", sim_matrix[i][sim_matrix_top_k[i][j]].item())
            retrieval_infos[texts[i]].append({"video": video_paths[sim_matrix_top_k[i][j]], "prob": sim_matrix[i][sim_matrix_top_k[i][j]].item(), "rank": j+1})
    return retrieval_infos

if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
            video_features.append(video_feature)

    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
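
demo.py ranks the example videos against each text prompt with `get_top_videos`; the `get_similarity` helper it defines (but never calls) performs the same cosine-similarity computation for a single pair of feature tensors. A rough sketch of using it on one clip, assuming the model and helpers defined above and a local `video1.mp4` (hypothetical file); the exact feature shapes depend on the model's encoders:

frames, _, _ = read_frames_decord("video1.mp4", 8)              # hypothetical local file
frames = model.transform(frames).unsqueeze(0).to(model.device)

with torch.no_grad():
    video_feature = model.encode_vision(frames, test=True)       # assumed shape (1, D)
    text_features = get_text_feature(model, ["a person talking", "a logo", "a building"])

# get_similarity L2-normalizes both inputs, so this is a cosine-similarity
# matrix of shape (num_texts, num_videos).
sim = get_similarity(video_feature.float(), text_features.float())
print(sim)

Note that `get_top_videos` hard-codes `top_k = 5`, so running the `__main__` block with only the two listed videos would require either more videos or a smaller `top_k`.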
    	
flash_attention_class.py ADDED
    
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """

        # qkv = qkv.to(torch.float16)

        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None
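
A minimal usage sketch for the FlashAttention module above, assuming a CUDA device, the flash-attn package, fp16 inputs, and a packed qkv tensor shaped (B, S, 3, H, D) as described in the forward docstring; the import path is an assumption and may need adjusting to your layout.

import torch

from flash_attention_class import FlashAttention  # import path assumed; adjust to your layout

B, S, H, D = 2, 196, 16, 88                     # batch, tokens, heads, head dim (88 = 1408 / 16)
qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16, device="cuda")

attn = FlashAttention(attention_dropout=0.0).eval()
with torch.no_grad():
    out, _ = attn(qkv)                          # full sequences, no key_padding_mask
print(out.shape)                                # torch.Size([2, 196, 16, 88])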
    	
        internvideo2.py
    ADDED
    
    | @@ -0,0 +1,779 @@ | |
import math
import torch
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import nn

import torch.utils.checkpoint as checkpoint
from functools import partial
from einops import rearrange

from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed, interpolate_pos_embed_internvideo2
from .flash_attention_class import FlashAttention

from transformers.utils import logging as error_logging

# Set up logging
error_logging.set_verbosity_error()

try:
    from flash_attn.modules.mlp import Mlp as FusedMLP
except:
    pass

try:
    from flash_attn.ops.rms_norm import DropoutAddRMSNorm
except:
    pass


class CrossAttention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class AttentiveBlock(nn.Module):

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_attn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        # x_q = x.mean(1, keepdim=True)
        x_q = x
        x_kv, pos_q, pos_k = x, 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
        x = x.squeeze(1)
        return x
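

# Illustrative shape sketch (small assumed values, not used anywhere by the model):
# with the mean-pooling query left commented out above, AttentionPoolingBlock maps
# (B, N, dim) -> (B, N, out_dim) via cross-attention; the model below instantiates it
# as clip_projector with dim=embed_dim and out_dim=clip_embed_dim.
def _attention_pooling_shape_sketch():
    pool = AttentionPoolingBlock(
        dim=64, num_heads=4, qkv_bias=True, qk_scale=None,
        drop=0., attn_drop=0., norm_layer=nn.LayerNorm, out_dim=32)
    tokens = torch.randn(2, 10, 64)     # (B, N, dim)
    pooled = pool(tokens)               # (B, N, out_dim); squeeze(1) is a no-op for N > 1
    assert pooled.shape == (2, 10, 32)
    return pooled.shape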


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
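

# Illustrative sanity sketch (assumed toy values): RMSNorm above computes
# weight * x / sqrt(mean(x**2, dim=-1) + eps), with the statistics taken in fp32.
def _rmsnorm_sanity_sketch():
    norm = RMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 5, 8)
    manual = norm.weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), manual, atol=1e-5)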


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))
        self.force_fp32 = force_fp32

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x):
        if self.force_fp32:
            output_type = x.dtype
            out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
            return out.to(dtype=output_type)
        else:
            out = x.mul_(self.gamma) if self.inplace else x * self.gamma
            return out


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
                 causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.use_flash_attn = use_flash_attn
        if use_flash_attn:
            self.causal = causal
            self.inner_attn = FlashAttention(attention_dropout=attn_drop)

        self.qk_normalization = qk_normalization
        self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def _naive_attn(self, x):
        B, N, C = x.shape
        # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        # attn = attn - attn.max(-1)[0].unsqueeze(-1)  # in case of overflow for fp16
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):

        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            if self.use_fused_rmsnorm:
                q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
                k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
            else:
                q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
                k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, x):
        x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
        return x
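

# Illustrative sketch (assumed toy values): with use_flash_attn=False the module runs
# _naive_attn on any device; the flash path additionally requires CUDA and fp16/bf16
# inputs (see FlashAttention.forward). Either way the output keeps the input shape.
def _attention_path_sketch():
    attn = Attention(dim=64, num_heads=4, qkv_bias=True, use_flash_attn=False,
                     qk_normalization=True, norm_layer=nn.LayerNorm)
    x = torch.randn(2, 10, 64)          # (B, N, C)
    assert attn(x).shape == x.shape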


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
            fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
            use_fused_rmsnorm=False):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                              use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
                              qk_normalization=qk_normalization,
                              use_fused_rmsnorm=use_fused_rmsnorm)
        self.ls1 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if use_fused_mlp:
            # self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.with_cp = with_cp
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def forward(self, x, residual=None):

        def _inner_forward(x, residual=None):
            if self.use_fused_rmsnorm:
                x, residual = self.norm1(x, residual)
                x = self.drop_path1(self.ls1(self.attn(x)))
                x, residual = self.norm2(x, residual)
                x = self.drop_path2(self.ls2(self.mlp(x)))
                return x, residual
            else:
                assert residual is None
                x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
                return x

        if self.with_cp:
            # print(f"\033[31m use_checkpoint\033[0m")
            return checkpoint.checkpoint(_inner_forward, x, residual)
        else:
            return _inner_forward(x, residual=residual)
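

# Illustrative sketch (assumed toy values): with use_fused_rmsnorm=False, Block.forward
# takes a single tensor and applies two pre-norm residual updates (attention, then MLP);
# the fused path instead threads an explicit (x, residual) pair through DropoutAddRMSNorm.
def _block_forward_sketch():
    blk = Block(dim=64, num_heads=4, mlp_ratio=4., qkv_bias=True,
                init_values=1e-5, norm_layer=partial(RMSNorm, eps=1e-6),
                qk_normalization=True)
    x = torch.randn(2, 10, 64)
    assert blk(x).shape == x.shape      # non-fused path returns one tensor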


class PatchEmbed(nn.Module):
    """ 3D Image to Patch Embedding
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
            num_frames=8, tubelet_size=1, norm_layer=None
        ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1]
        ) # (T, H, W)
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.num_img_patches = self.grid_size[1] * self.grid_size[2]

        self.proj = nn.Conv3d(
            in_channels=in_chans, out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
            stride=(tubelet_size, patch_size[0], patch_size[1])
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(3).permute(0, 2, 3, 1)  # B x C x T x HW => B x T x HW x C
        x = self.norm(x)
        return x
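

# Illustrative sketch using the defaults of PretrainInternVideo2 below
# (img_size=224, patch_size=14, num_frames=8, tubelet_size=1): grid_size is (8, 16, 16),
# i.e. 8 * 16 * 16 = 2048 patch tokens per video before the cls token is prepended.
def _patch_embed_shape_sketch():
    pe = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=32,
                    num_frames=8, tubelet_size=1)
    video = torch.randn(1, 3, 8, 224, 224)      # (B, C, T, H, W)
    tokens = pe(video)                          # (B, T, H*W, embed_dim)
    assert pe.grid_size == (8, 16, 16) and pe.num_patches == 2048
    assert tokens.shape == (1, 8, 256, 32)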
| 349 | 
            +
             | 
| 350 | 
            +
             | 
| 351 | 
            +
            class Linear_Decoder(nn.Module):
         | 
| 352 | 
            +
                def __init__(self, in_channels=1408, out_channels=3200, 
         | 
| 353 | 
            +
                             norm_layer=nn.LayerNorm, clip_norm_type='l2'):
         | 
| 354 | 
            +
                    super().__init__()
         | 
| 355 | 
            +
                    self.clip_norm_type = clip_norm_type
         | 
| 356 | 
            +
                    # logger.info(f'Normalization Type: {clip_norm_type}')
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    self.head = nn.Linear(in_channels, out_channels)
         | 
| 359 | 
            +
                    self.norm =  norm_layer(out_channels)
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    self.apply(self._init_weights)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                def _init_weights(self, m):
         | 
| 364 | 
            +
                    if isinstance(m, nn.Linear):
         | 
| 365 | 
            +
                        nn.init.xavier_uniform_(m.weight)
         | 
| 366 | 
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         | 
| 367 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 368 | 
            +
                    elif isinstance(m, nn.LayerNorm):
         | 
| 369 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 370 | 
            +
                        nn.init.constant_(m.weight, 1.0)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                def forward(self, x):
         | 
| 373 | 
            +
                    x = self.norm(self.head(x))
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                    if self.clip_norm_type == 'l2':
         | 
| 376 | 
            +
                        x = x / x.norm(dim=-1, keepdim=True)
         | 
| 377 | 
            +
                    elif self.clip_norm_type == 'none':
         | 
| 378 | 
            +
                        pass
         | 
| 379 | 
            +
                    else:
         | 
| 380 | 
            +
                        raise NotImplementedError
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    return x
         | 
| 383 | 
            +
             | 
| 384 | 
            +
             | 
| 385 | 
            +
            class PretrainInternVideo2(nn.Module):
         | 
| 386 | 
            +
                def __init__(
         | 
| 387 | 
            +
                        self,
         | 
| 388 | 
            +
                        in_chans: int = 3,
         | 
| 389 | 
            +
                        patch_size: int = 14,
         | 
| 390 | 
            +
                        img_size: int = 224,
         | 
| 391 | 
            +
                        qkv_bias: bool = False,
         | 
| 392 | 
            +
                        drop_path_rate: float = 0.25,
         | 
| 393 | 
            +
                        embed_dim: int = 1408,
         | 
| 394 | 
            +
                        num_heads: int = 16,
         | 
| 395 | 
            +
                        mlp_ratio: float = 48/11,
         | 
| 396 | 
            +
                        init_values: float = 1e-5,
         | 
| 397 | 
            +
                        qk_normalization: bool = True,
         | 
| 398 | 
            +
                        depth: int = 40,
         | 
| 399 | 
            +
                        use_flash_attn: bool = False,
         | 
| 400 | 
            +
                        use_fused_rmsnorm: bool = False,
         | 
| 401 | 
            +
                        use_fused_mlp: bool = False,
         | 
| 402 | 
            +
                        fused_mlp_heuristic: int = 1,
         | 
| 403 | 
            +
                        attn_pool_num_heads: int = 16,
         | 
| 404 | 
            +
                        clip_embed_dim: int = 768,
         | 
| 405 | 
            +
                        layerscale_no_force_fp32: bool = False,
         | 
| 406 | 
            +
                        num_frames: int = 8,
         | 
| 407 | 
            +
                        tubelet_size: int = 1,
         | 
| 408 | 
            +
                        sep_pos_embed: bool = False,
         | 
| 409 | 
            +
                        sep_image_video_pos_embed: bool = False,
         | 
| 410 | 
            +
                        use_checkpoint: bool = False,
         | 
| 411 | 
            +
                        checkpoint_num: int = 0,
         | 
| 412 | 
            +
                        # for unmasked teacher
         | 
| 413 | 
            +
                        clip_teacher_embed_dim: int = 3200,
         | 
| 414 | 
            +
                        clip_teacher_final_dim: int = 768, # if 0, not distill final features
         | 
| 415 | 
            +
                        clip_norm_type: str = 'l2',
         | 
| 416 | 
            +
                        clip_return_layer: int = 1,
         | 
| 417 | 
            +
                        clip_student_return_interval: int = 1,
         | 
| 418 | 
            +
                    ):
         | 
| 419 | 
            +
                    super().__init__()
         | 
| 420 | 
            +
                    
         | 
| 421 | 
            +
                    self.num_frames = num_frames
         | 
| 422 | 
            +
                    # print(f'num_frames: {num_frames}')
         | 
| 423 | 
            +
                    self.tubelet_size = tubelet_size
         | 
| 424 | 
            +
                    assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'
         | 
| 425 | 
            +
                    
         | 
| 426 | 
            +
                    self.use_flash_attn = use_flash_attn
         | 
| 427 | 
            +
                    self.embed_dim = embed_dim
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    self.depth = depth
         | 
| 430 | 
            +
                    self.clip_norm_type = clip_norm_type
         | 
| 431 | 
            +
                    self.return_index = []
         | 
| 432 | 
            +
                    for i in range(clip_return_layer):
         | 
| 433 | 
            +
                        self.return_index.append(depth - int(i * clip_student_return_interval) - 1)
         | 
| 434 | 
            +
                    # logger.info(f'Normalization Type: {clip_norm_type}')
         | 
| 435 | 
            +
                    # logger.info(f'Strudent Return Index: {self.return_index}')
         | 
| 436 | 
            +
                    
         | 
| 437 | 
            +
                    if use_fused_rmsnorm:
         | 
| 438 | 
            +
                        norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
         | 
| 439 | 
            +
                    else:
         | 
| 440 | 
            +
                        norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
         | 
| 441 | 
            +
                    self.norm_layer_for_blocks = norm_layer_for_blocks
         | 
| 442 | 
            +
                    self.patch_embed = PatchEmbed(
         | 
| 443 | 
            +
                        img_size, patch_size, in_chans, embed_dim,
         | 
| 444 | 
            +
                        num_frames=num_frames, tubelet_size=tubelet_size,
         | 
| 445 | 
            +
                    )
         | 
| 446 | 
            +
                    num_patches = self.patch_embed.num_patches
         | 
| 447 | 
            +
                    num_img_patches = self.patch_embed.num_img_patches
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         | 
| 450 | 
            +
                    
         | 
| 451 | 
            +
                    # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
         | 
| 452 | 
            +
                    self.sep_pos_embed = sep_pos_embed
         | 
| 453 | 
            +
                    self.sep_image_video_pos_embed = sep_image_video_pos_embed
         | 
| 454 | 
            +
                    if sep_pos_embed:
         | 
| 455 | 
            +
                        raise NotImplementedError
         | 
| 456 | 
            +
                    else:
         | 
| 457 | 
            +
                        if sep_image_video_pos_embed:
         | 
| 458 | 
            +
                            # logger.info("Use joint position embedding, for image and video we use different pos_embed.")
         | 
| 459 | 
            +
                            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 460 | 
            +
                            self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
         | 
| 461 | 
            +
                            # for CLIP decoder
         | 
| 462 | 
            +
                            self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 463 | 
            +
                            self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
         | 
| 464 | 
            +
                        else:
         | 
| 465 | 
            +
                            # logger.info("Use joint position embedding, for image and video we use same pos_embed.")
         | 
| 466 | 
            +
                            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 467 | 
            +
                            self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 468 | 
            +
                    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
         | 
| 469 | 
            +
                    # choose which layer to use checkpoint
         | 
| 470 | 
            +
                    with_cp_list = [False] * depth
         | 
| 471 | 
            +
                    if use_checkpoint:
         | 
| 472 | 
            +
                        for idx in range(depth):
         | 
| 473 | 
            +
                            if idx < checkpoint_num:
         | 
| 474 | 
            +
                                with_cp_list[idx] = True
         | 
| 475 | 
            +
                    # logger.info(f"Droppath rate: {dpr}")
         | 
| 476 | 
            +
                    # logger.info(f"Checkpoint list: {with_cp_list}")
         | 
| 477 | 
            +
                    
         | 
| 478 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 479 | 
            +
                        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
         | 
| 480 | 
            +
                              norm_layer=norm_layer_for_blocks,
         | 
| 481 | 
            +
                              drop_path=dpr[i], init_values=init_values, attn_drop=0.,
         | 
| 482 | 
            +
                              use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
         | 
| 483 | 
            +
                              fused_mlp_heuristic=fused_mlp_heuristic,
         | 
| 484 | 
            +
                              with_cp=with_cp_list[i],
         | 
| 485 | 
            +
                              qk_normalization=qk_normalization,
         | 
| 486 | 
            +
                              layerscale_no_force_fp32=layerscale_no_force_fp32,
         | 
| 487 | 
            +
                              use_fused_rmsnorm=use_fused_rmsnorm)
         | 
| 488 | 
            +
                        for i in range(depth)])
         | 
| 489 | 
            +
                    self.clip_projector = AttentionPoolingBlock(
         | 
| 490 | 
            +
                        dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
         | 
| 491 | 
            +
                        drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
         | 
| 492 | 
            +
                    
         | 
| 493 | 
            +
                    # CLIP decoder
         | 
| 494 | 
            +
                    self.clip_decoder = nn.ModuleList([
         | 
| 495 | 
            +
                        Linear_Decoder(
         | 
| 496 | 
            +
                            in_channels=embed_dim, 
         | 
| 497 | 
            +
                            out_channels=clip_teacher_embed_dim, 
         | 
| 498 | 
            +
                            norm_layer=partial(nn.LayerNorm, eps=1e-5), 
         | 
| 499 | 
            +
                            clip_norm_type=clip_norm_type
         | 
| 500 | 
            +
                        ) for _ in range(clip_return_layer)
         | 
| 501 | 
            +
                    ])
         | 
| 502 | 
            +
                    self.final_clip_decoder = nn.Identity()
         | 
| 503 | 
            +
                    if clip_teacher_final_dim > 0:
         | 
| 504 | 
            +
                        self.final_clip_decoder = Linear_Decoder(
         | 
| 505 | 
            +
                            in_channels=clip_embed_dim, 
         | 
| 506 | 
            +
                            out_channels=clip_teacher_final_dim, 
         | 
| 507 | 
            +
                            norm_layer=partial(nn.LayerNorm, eps=1e-5), 
         | 
| 508 | 
            +
                            clip_norm_type=clip_norm_type
         | 
| 509 | 
            +
                        )
         | 
| 510 | 
            +
                    
         | 
| 511 | 
            +
                    self.init_pos_embed()
         | 
| 512 | 
            +
                    trunc_normal_(self.cls_token, std=.02)
         | 
| 513 | 
            +
                    self.apply(self._init_weights)
         | 
| 514 | 
            +
                    self.fix_init_weight()
         | 
| 515 | 
            +
             | 
| 516 | 
            +
                def init_pos_embed(self):
         | 
| 517 | 
            +
                    # logger.info("Init pos_embed from sincos pos_embed")
         | 
| 518 | 
            +
                    if self.sep_pos_embed:
         | 
| 519 | 
            +
                        raise NotImplementedError
         | 
| 520 | 
            +
                    else:
         | 
| 521 | 
            +
                        # trunc_normal_(self.pos_embed, std=.02)
         | 
| 522 | 
            +
                        # trunc_normal_(self.clip_pos_embed, std=.02)
         | 
| 523 | 
            +
                        pos_embed = get_3d_sincos_pos_embed(
         | 
| 524 | 
            +
                            self.pos_embed.shape[-1], 
         | 
| 525 | 
            +
                            self.patch_embed.grid_size[1], # height & weight
         | 
| 526 | 
            +
                            self.patch_embed.grid_size[0], # t_size
         | 
| 527 | 
            +
                            cls_token=True
         | 
| 528 | 
            +
                        )
         | 
| 529 | 
            +
                        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
         | 
| 530 | 
            +
                        self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
         | 
| 531 | 
            +
                        
         | 
| 532 | 
            +
                        if self.sep_image_video_pos_embed:
         | 
| 533 | 
            +
                            img_pos_embed = get_3d_sincos_pos_embed(
         | 
| 534 | 
            +
                                self.pos_embed.shape[-1], 
         | 
| 535 | 
            +
                                self.patch_embed.grid_size[1], # height & weight
         | 
| 536 | 
            +
                                1,
         | 
| 537 | 
            +
                                cls_token=True
         | 
| 538 | 
            +
                            )
         | 
| 539 | 
            +
                            self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
         | 
| 540 | 
            +
                            self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                def _init_weights(self, m):
         | 
| 543 | 
            +
                    if isinstance(m, nn.Linear):
         | 
| 544 | 
            +
                        trunc_normal_(m.weight, std=.02)
         | 
| 545 | 
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         | 
| 546 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 547 | 
            +
                    elif isinstance(m, nn.LayerNorm):
         | 
| 548 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 549 | 
            +
                        nn.init.constant_(m.weight, 1.0)
         | 
| 550 | 
            +
             | 
| 551 | 
            +
                def fix_init_weight(self):
         | 
| 552 | 
            +
                    def rescale(param, layer_id):
         | 
| 553 | 
            +
                        param.div_(math.sqrt(2.0 * layer_id))
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                    for layer_id, layer in enumerate(self.blocks):
         | 
| 556 | 
            +
                        rescale(layer.attn.proj.weight.data, layer_id + 1)
         | 
| 557 | 
            +
                        rescale(layer.mlp.fc2.weight.data, layer_id + 1)
         | 
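
The rescaling above is a depth-dependent initialization trick seen in BEiT/EVA-style ViT codebases: for block l, counted from 1, the attention output projection and the second MLP weight are divided so that deep residual branches start with small contributions,

\[ W_l \leftarrow \frac{W_l}{\sqrt{2\,l}}. \]
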
| 558 | 
            +
                
         | 
| 559 | 
            +
                @property
         | 
| 560 | 
            +
                def dtype(self):
         | 
| 561 | 
            +
                    return self.patch_embed.proj.weight.dtype
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                def get_num_layers(self):
         | 
| 564 | 
            +
                    return len(self.blocks)
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                @torch.jit.ignore
         | 
| 567 | 
            +
                def no_weight_decay(self):
         | 
| 568 | 
            +
                    return {
         | 
| 569 | 
            +
                        'pos_embed', 
         | 
| 570 | 
            +
                        'pos_embed_spatial', 
         | 
| 571 | 
            +
                        'pos_embed_temporal', 
         | 
| 572 | 
            +
                        'pos_embed_cls',
         | 
| 573 | 
            +
                        'img_pos_embed',
         | 
| 574 | 
            +
                        'cls_token',
         | 
| 575 | 
            +
                        'clip_pos_embed', 
         | 
| 576 | 
            +
                        'clip_pos_embed_spatial', 
         | 
| 577 | 
            +
                        'clip_pos_embed_temporal', 
         | 
| 578 | 
            +
                        'clip_pos_embed_cls',
         | 
| 579 | 
            +
                        'clip_img_pos_embed'
         | 
| 580 | 
            +
                    }
         | 
| 581 | 
            +
                
         | 
| 582 | 
            +
                # @torch.cuda.amp.autocast(enabled=False)
         | 
| 583 | 
            +
                def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
         | 
| 584 | 
            +
                    # print(0, x.shape)
         | 
| 585 | 
            +
                    x = self.patch_embed(x.type(self.dtype))
         | 
| 586 | 
            +
                    # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}")
         | 
| 587 | 
            +
                    B, T, L, C = x.shape  # T: temporal; L: spatial
         | 
| 588 | 
            +
                    x = x.view([B, T * L, C])   # (B, T * L, C)
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    # append cls token
         | 
| 591 | 
            +
                    cls_tokens = self.cls_token.expand(B, -1, -1)
         | 
| 592 | 
            +
                    x = torch.cat((cls_tokens, x), dim=1)   # (B, T * L + 1, C)
         | 
| 593 | 
            +
                    # print(1, x.shape)
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    # add pos_embed
         | 
| 596 | 
            +
                    if self.sep_pos_embed:
         | 
| 597 | 
            +
                        raise NotImplementedError
         | 
| 598 | 
            +
                    else:
         | 
| 599 | 
            +
                        if use_image:
         | 
| 600 | 
            +
                            # print('use image')  # No.
         | 
| 601 | 
            +
                            if self.sep_image_video_pos_embed:
         | 
| 602 | 
            +
                                pos_embed = self.img_pos_embed
         | 
| 603 | 
            +
                            else:
         | 
| 604 | 
            +
                                # (1, num_img_patches + 1, embed_dim)
         | 
| 605 | 
            +
                                # print('origin pos_embed.shape:', self.pos_embed.shape)
         | 
| 606 | 
            +
                                cls_pos_embed = self.pos_embed[:, 0:1, :]
         | 
| 607 | 
            +
                                # print('cls_pos_embed.shape:', cls_pos_embed.shape)
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                                img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
         | 
| 610 | 
            +
                                # print('img_pos_embed.shape:', img_pos_embed.shape)
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                                pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
         | 
| 613 | 
            +
                                # print('final img_pos_embed.shape:', pos_embed.shape)
         | 
| 614 | 
            +
                        else:
         | 
| 615 | 
            +
                            pos_embed = self.pos_embed
         | 
| 616 | 
            +
                    pos_embed = pos_embed[:, :x.shape[1], :]
         | 
| 617 | 
            +
                    x = x + pos_embed
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                    # mask tokens, ~mask means visible
         | 
| 620 | 
            +
                    if mask is not None:
         | 
| 621 | 
            +
                        x = x[~mask].reshape(B, -1, C) 
         | 
| 622 | 
            +
                    else:
         | 
| 623 | 
            +
                        x = x.reshape(B, -1, C) 
         | 
| 624 | 
            +
                    residual = None
         | 
| 625 | 
            +
                    x_clip = []
         | 
| 626 | 
            +
                    for idx, blk in enumerate(self.blocks):
         | 
| 627 | 
            +
                        if isinstance(x, tuple) and len(x) == 2:
         | 
| 628 | 
            +
                            x, residual = x
         | 
| 629 | 
            +
                        # print(f"\033[31m这是{idx}, {x.shape}\033[0m")
         | 
| 630 | 
            +
                        x = blk(x, residual=residual)
         | 
| 631 | 
            +
                        # return intermediate features
         | 
| 632 | 
            +
                        if idx in self.return_index:
         | 
| 633 | 
            +
                            if isinstance(x, tuple) and len(x) == 2:
         | 
| 634 | 
            +
                                tmp_x, tmp_residual = x
         | 
| 635 | 
            +
                                if residual is not None:
         | 
| 636 | 
            +
                                    x_clip.append(tmp_x + tmp_residual)
         | 
| 637 | 
            +
                            else:
         | 
| 638 | 
            +
                                x_clip.append(x)
         | 
| 639 | 
            +
                        if idx == (self.depth + x_vis_return_idx):
         | 
| 640 | 
            +
                            # print(f'idx = {idx} len(self.blocks)={len(self.blocks)}')
         | 
| 641 | 
            +
                            break
         | 
| 642 | 
            +
                    
         | 
| 643 | 
            +
                    if isinstance(x, tuple) and len(x) == 2:
         | 
| 644 | 
            +
                        x, residual = x
         | 
| 645 | 
            +
                        if residual is not None:
         | 
| 646 | 
            +
                            x = x + residual
         | 
| 647 | 
            +
                    
         | 
| 648 | 
            +
                    x_vis = x
         | 
| 649 | 
            +
                    # print(f'x_vis.shape:{x_vis.shape}')
         | 
| 650 | 
            +
                    if x_vis_only:
         | 
| 651 | 
            +
                        return x_vis
         | 
| 652 | 
            +
                    
         | 
| 653 | 
            +
                    x_pool_vis = self.clip_projector(x_vis) 
         | 
| 654 | 
            +
                    x_align = self.final_clip_decoder(x_pool_vis)
         | 
| 655 | 
            +
                    # print(3, x_pool_vis.shape)
         | 
| 656 | 
            +
                    # print(4, x_align.shape)
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                    # align CLIP
         | 
| 659 | 
            +
                    x_clip = torch.stack(x_clip)
         | 
| 660 | 
            +
                    K, B, _, C_CLIP = x_clip.shape
         | 
| 661 | 
            +
                    # print(5, x_clip.shape)
         | 
| 662 | 
            +
                    # add pos_embed
         | 
| 663 | 
            +
                    if self.sep_pos_embed: 
         | 
| 664 | 
            +
                        raise NotImplementedError
         | 
| 665 | 
            +
                    else:
         | 
| 666 | 
            +
                        if use_image:
         | 
| 667 | 
            +
                            if self.sep_image_video_pos_embed:
         | 
| 668 | 
            +
                                clip_pos_embed = self.clip_img_pos_embed
         | 
| 669 | 
            +
                            else:
         | 
| 670 | 
            +
                                # (1, num_img_patches + 1, embed_dim)
         | 
| 671 | 
            +
                                # print('origin pos_embed.shape:', self.pos_embed.shape)
         | 
| 672 | 
            +
                                clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
         | 
| 673 | 
            +
                                # print('cls_pos_embed.shape:', cls_pos_embed.shape)
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                                clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
         | 
| 676 | 
            +
                                # print('img_pos_embed.shape:', img_pos_embed.shape)
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                                clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1)
         | 
| 679 | 
            +
                                # print('final img_pos_embed.shape:', pos_embed.shape)
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                        else:
         | 
| 682 | 
            +
                            clip_pos_embed = self.clip_pos_embed
         | 
| 683 | 
            +
                    
         | 
| 684 | 
            +
                    clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
         | 
| 685 | 
            +
                    if mask is not None:
         | 
| 686 | 
            +
                        x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
         | 
| 687 | 
            +
                    else:
         | 
| 688 | 
            +
                        clip_pos_embed = clip_pos_embed.unsqueeze(0).repeat(K, 1, 1, 1)
         | 
| 689 | 
            +
                        clip_pos_embed = clip_pos_embed[:, :, :x_clip.shape[2], :]
         | 
| 690 | 
            +
                        x_clip = x_clip + clip_pos_embed
         | 
| 691 | 
            +
                    
         | 
| 692 | 
            +
                    # CLIP decoder
         | 
| 693 | 
            +
                    x_clip_align = []
         | 
| 694 | 
            +
                    for idx, clip_decoder in enumerate(self.clip_decoder):
         | 
| 695 | 
            +
                        x_clip_align.append(clip_decoder(x_clip[idx]))
         | 
| 696 | 
            +
                    x_clip_align = torch.stack(x_clip_align)
         | 
| 697 | 
            +
                    
         | 
| 698 | 
            +
                    # print(f'x_vis.shape:{x_vis.shape}, x_pool_vis.shape:{x_pool_vis.shape}')
         | 
| 699 | 
            +
                    return x_vis, x_pool_vis, x_clip_align, x_align
         | 
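
For orientation, a minimal usage sketch of the four return values above; `model` and `video` are assumed names for a constructed PretrainInternVideo2 and a (B, C, T, H, W) tensor, and the comments restate the forward pass above rather than any measured run.

    # Hypothetical call, full-video path (no mask, use_image=False).
    x_vis, x_pool_vis, x_clip_align, x_align = model(video, mask=None, use_image=False)
    # x_vis:        tokens from the last kept encoder block, (B, N, embed_dim)
    # x_pool_vis:   attention-pooled embedding from clip_projector
    # x_clip_align: stacked intermediate features after the clip_decoder heads
    # x_align:      final_clip_decoder projection of x_pool_vis
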
| 700 | 
            +
                
         | 
| 701 | 
            +
             | 
| 702 | 
            +
            def pretrain_internvideo2_1b_patch14_224(config):
         | 
| 703 | 
            +
                # print(config.vision_encoder.num_frames)
         | 
| 704 | 
            +
                model = PretrainInternVideo2(
         | 
| 705 | 
            +
                    in_chans=3, img_size=224, patch_size=14,
         | 
| 706 | 
            +
                    embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
         | 
| 707 | 
            +
                    clip_embed_dim=config.vision_encoder.clip_embed_dim,
         | 
| 708 | 
            +
                    attn_pool_num_heads=16, qkv_bias=False,
         | 
| 709 | 
            +
                    drop_path_rate=0.25,
         | 
| 710 | 
            +
                    init_values=0.00001,
         | 
| 711 | 
            +
                    qk_normalization=True,
         | 
| 712 | 
            +
                    use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
         | 
| 713 | 
            +
                    use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
         | 
| 714 | 
            +
                    use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
         | 
| 715 | 
            +
                    fused_mlp_heuristic=1,
         | 
| 716 | 
            +
                    layerscale_no_force_fp32=False,
         | 
| 717 | 
            +
                    num_frames=config.vision_encoder.num_frames,
         | 
| 718 | 
            +
                    tubelet_size=config.vision_encoder.tubelet_size,
         | 
| 719 | 
            +
                    sep_pos_embed=False,
         | 
| 720 | 
            +
                    sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
         | 
| 721 | 
            +
                    use_checkpoint=config.vision_encoder.use_checkpoint,
         | 
| 722 | 
            +
                    checkpoint_num=config.vision_encoder.checkpoint_num,
         | 
| 723 | 
            +
                    clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
         | 
| 724 | 
            +
                    clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
         | 
| 725 | 
            +
                    clip_norm_type=config.vision_encoder.clip_norm_type,
         | 
| 726 | 
            +
                    clip_return_layer=config.vision_encoder.clip_return_layer,
         | 
| 727 | 
            +
                    clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
         | 
| 728 | 
            +
                )
         | 
| 729 | 
            +
             | 
| 730 | 
            +
                if config.vision_encoder.pretrained is not None:
         | 
| 731 | 
            +
                    # logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
         | 
| 732 | 
            +
                    state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
         | 
| 733 | 
            +
                    interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
         | 
| 734 | 
            +
                    message = model.load_state_dict(state_dict, strict=False)
         | 
| 735 | 
            +
                    # logger.info(message)
         | 
| 736 | 
            +
                else:
         | 
| 737 | 
            +
                    pass
         | 
| 738 | 
            +
                    # logger.info("No pretrained weights!!!")
         | 
| 739 | 
            +
                return model
         | 
| 740 | 
            +
             | 
| 741 | 
            +
             | 
| 742 | 
            +
             | 
| 743 | 
            +
            def pretrain_internvideo2_6b_patch14_224(config):
         | 
| 744 | 
            +
                model = PretrainInternVideo2(
         | 
| 745 | 
            +
                    in_chans=3, img_size=224, patch_size=14,
         | 
| 746 | 
            +
                    embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
         | 
| 747 | 
            +
                    clip_embed_dim=config.vision_encoder.clip_embed_dim,
         | 
| 748 | 
            +
                    attn_pool_num_heads=16, qkv_bias=False,
         | 
| 749 | 
            +
                    drop_path_rate=0.3,
         | 
| 750 | 
            +
                    init_values=0.00001,
         | 
| 751 | 
            +
                    qk_normalization=True,
         | 
| 752 | 
            +
                    use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
         | 
| 753 | 
            +
                    use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
         | 
| 754 | 
            +
                    use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
         | 
| 755 | 
            +
                    fused_mlp_heuristic=1,
         | 
| 756 | 
            +
                    layerscale_no_force_fp32=False,
         | 
| 757 | 
            +
                    num_frames=config.vision_encoder.num_frames,
         | 
| 758 | 
            +
                    tubelet_size=config.vision_encoder.tubelet_size,
         | 
| 759 | 
            +
                    sep_pos_embed=False,
         | 
| 760 | 
            +
                    sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
         | 
| 761 | 
            +
                    use_checkpoint=config.vision_encoder.use_checkpoint,
         | 
| 762 | 
            +
                    checkpoint_num=config.vision_encoder.checkpoint_num,
         | 
| 763 | 
            +
                    clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
         | 
| 764 | 
            +
                    clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
         | 
| 765 | 
            +
                    clip_norm_type=config.vision_encoder.clip_norm_type,
         | 
| 766 | 
            +
                    clip_return_layer=config.vision_encoder.clip_return_layer,
         | 
| 767 | 
            +
                    clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
         | 
| 768 | 
            +
                )
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                if config.vision_encoder.pretrained is not None:
         | 
| 771 | 
            +
                    # logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
         | 
| 772 | 
            +
                    state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
         | 
| 773 | 
            +
                    interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
         | 
| 774 | 
            +
                    msg = model.load_state_dict(state_dict, strict=False)
         | 
| 775 | 
            +
                    # logger.info(msg)
         | 
| 776 | 
            +
                else:
         | 
| 777 | 
            +
                    pass
         | 
| 778 | 
            +
                    # logger.info("No pretrained weights!!!")
         | 
| 779 | 
            +
                return model
         | 
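
A minimal construction sketch, assuming the files in this commit are importable as a package; easydict.EasyDict is used here only as a stand-in for the repo's config object, and every field value is illustrative rather than the shipped configuration.

    from easydict import EasyDict  # assumed stand-in config wrapper, not part of this repo

    cfg = EasyDict(dict(vision_encoder=dict(
        clip_embed_dim=768,                 # illustrative values only
        clip_teacher_embed_dim=3200,
        clip_teacher_final_dim=768,
        clip_norm_type='l2',
        clip_return_layer=6,
        clip_student_return_interval=1,
        num_frames=8,
        tubelet_size=1,
        sep_image_video_pos_embed=True,
        use_checkpoint=False,
        checkpoint_num=0,
        use_flash_attn=True,                # the three fused/flash flags must agree
        use_fused_rmsnorm=True,
        use_fused_mlp=True,
        pretrained=None,                    # skip checkpoint loading in this sketch
    )))
    model = pretrain_internvideo2_1b_patch14_224(cfg)
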
    	
        internvideo2_clip_vision.py
    ADDED
    
    | @@ -0,0 +1,553 @@ | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
            from timm.models.layers import DropPath, to_2tuple, trunc_normal_
         | 
| 6 | 
            +
            from timm.models.registry import register_model
         | 
| 7 | 
            +
            from torch import nn
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch.utils.checkpoint as checkpoint
         | 
| 10 | 
            +
            from functools import partial
         | 
| 11 | 
            +
            from einops import rearrange
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
         | 
| 14 | 
            +
            from .flash_attention_class import FlashAttention
         | 
| 15 | 
            +
            from flash_attn.modules.mlp import FusedMLP
         | 
| 16 | 
            +
            try:
         | 
| 17 | 
            +
                from flash_attn.ops.rms_norm import DropoutAddRMSNorm
         | 
| 18 | 
            +
            except:
         | 
| 19 | 
            +
                pass
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from transformers.utils import logging
         | 
| 22 | 
            +
            import warnings
         | 
| 23 | 
            +
            warnings.filterwarnings("ignore")
         | 
| 24 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            class CrossAttention(nn.Module):
         | 
| 28 | 
            +
                def __init__(
         | 
| 29 | 
            +
                        self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
         | 
| 30 | 
            +
                        proj_drop=0., attn_head_dim=None, out_dim=None):
         | 
| 31 | 
            +
                    super().__init__()
         | 
| 32 | 
            +
                    if out_dim is None:
         | 
| 33 | 
            +
                        out_dim = dim
         | 
| 34 | 
            +
                    self.num_heads = num_heads
         | 
| 35 | 
            +
                    head_dim = dim // num_heads
         | 
| 36 | 
            +
                    if attn_head_dim is not None:
         | 
| 37 | 
            +
                        head_dim = attn_head_dim
         | 
| 38 | 
            +
                    all_head_dim = head_dim * self.num_heads
         | 
| 39 | 
            +
                    self.scale = qk_scale or head_dim ** -0.5
         | 
| 40 | 
            +
                    assert all_head_dim == dim
         | 
| 41 | 
            +
                    
         | 
| 42 | 
            +
                    self.q = nn.Linear(dim, all_head_dim, bias=False)
         | 
| 43 | 
            +
                    self.k = nn.Linear(dim, all_head_dim, bias=False)
         | 
| 44 | 
            +
                    self.v = nn.Linear(dim, all_head_dim, bias=False)
         | 
| 45 | 
            +
                    
         | 
| 46 | 
            +
                    if qkv_bias:
         | 
| 47 | 
            +
                        self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
         | 
| 48 | 
            +
                        self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
         | 
| 49 | 
            +
                        self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         | 
| 50 | 
            +
                    else:
         | 
| 51 | 
            +
                        self.q_bias = None
         | 
| 52 | 
            +
                        self.k_bias = None
         | 
| 53 | 
            +
                        self.v_bias = None
         | 
| 54 | 
            +
                    
         | 
| 55 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 56 | 
            +
                    self.proj = nn.Linear(all_head_dim, out_dim)
         | 
| 57 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                def forward(self, x, k=None, v=None):
         | 
| 60 | 
            +
                    B, N, C = x.shape
         | 
| 61 | 
            +
                    N_k = k.shape[1]
         | 
| 62 | 
            +
                    N_v = v.shape[1]
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    q_bias, k_bias, v_bias = None, None, None
         | 
| 65 | 
            +
                    if self.q_bias is not None:
         | 
| 66 | 
            +
                        q_bias = self.q_bias
         | 
| 67 | 
            +
                        k_bias = self.k_bias
         | 
| 68 | 
            +
                        v_bias = self.v_bias
         | 
| 69 | 
            +
                    
         | 
| 70 | 
            +
                    q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
         | 
| 71 | 
            +
                    q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)
         | 
| 72 | 
            +
                    
         | 
| 73 | 
            +
                    k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
         | 
| 74 | 
            +
                    k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
         | 
| 75 | 
            +
                    
         | 
| 76 | 
            +
                    v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
         | 
| 77 | 
            +
                    v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    q = q * self.scale
         | 
| 80 | 
            +
                    attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)
         | 
| 81 | 
            +
                    
         | 
| 82 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 83 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
         | 
| 86 | 
            +
                    x = self.proj(x)
         | 
| 87 | 
            +
                    x = self.proj_drop(x)
         | 
| 88 | 
            +
                    
         | 
| 89 | 
            +
                    return x
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class AttentiveBlock(nn.Module):
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
         | 
| 95 | 
            +
                             drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
         | 
| 96 | 
            +
                    super().__init__()
         | 
| 97 | 
            +
                    
         | 
| 98 | 
            +
                    self.norm1_q = norm_layer(dim)
         | 
| 99 | 
            +
                    self.norm1_k = norm_layer(dim)
         | 
| 100 | 
            +
                    self.norm1_v = norm_layer(dim)
         | 
| 101 | 
            +
                    self.cross_attn = CrossAttention(
         | 
| 102 | 
            +
                        dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
         | 
| 103 | 
            +
                        proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
         | 
| 104 | 
            +
                    
         | 
| 105 | 
            +
                    if drop_path > 0.:
         | 
| 106 | 
            +
                        logger.info(f"Use DropPath in projector: {drop_path}")
         | 
| 107 | 
            +
                    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 108 | 
            +
                
         | 
| 109 | 
            +
                def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
         | 
| 110 | 
            +
                    x_q = self.norm1_q(x_q + pos_q)
         | 
| 111 | 
            +
                    x_k = self.norm1_k(x_kv + pos_k)
         | 
| 112 | 
            +
                    x_v = self.norm1_v(x_kv)
         | 
| 113 | 
            +
                    x = self.cross_attn(x_q, k=x_k, v=x_v)
         | 
| 114 | 
            +
                    
         | 
| 115 | 
            +
                    return x
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            class AttentionPoolingBlock(AttentiveBlock):
         | 
| 119 | 
            +
                
         | 
| 120 | 
            +
                def forward(self, x):
         | 
| 121 | 
            +
                    x_q = x.mean(1, keepdim=True)
         | 
| 122 | 
            +
                    x_kv, pos_q, pos_k = x, 0, 0
         | 
| 123 | 
            +
                    x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
         | 
| 124 | 
            +
                    x = x.squeeze(1)
         | 
| 125 | 
            +
                    return x
         | 
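
A minimal shape check for the pooling path above, assuming the classes in this file are importable; the dimensions are arbitrary.

    import torch

    # Mean token as query, all tokens as keys/values; AttentionPoolingBlock.forward
    # squeezes the single query position to one vector per sample.
    pool = AttentionPoolingBlock(dim=64, num_heads=8, qkv_bias=True,
                                 norm_layer=torch.nn.LayerNorm, out_dim=32)
    tokens = torch.randn(2, 197, 64)   # (B, N, C)
    pooled = pool(tokens)              # (B, out_dim)
    assert pooled.shape == (2, 32)
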
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            class RMSNorm(nn.Module):
         | 
| 129 | 
            +
                def __init__(self, hidden_size, eps=1e-6):
         | 
| 130 | 
            +
                    super().__init__()
         | 
| 131 | 
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         | 
| 132 | 
            +
                    self.variance_epsilon = eps
         | 
| 133 | 
            +
                
         | 
| 134 | 
            +
                def forward(self, hidden_states):
         | 
| 135 | 
            +
                    input_dtype = hidden_states.dtype
         | 
| 136 | 
            +
                    hidden_states = hidden_states.to(torch.float32)
         | 
| 137 | 
            +
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
         | 
| 138 | 
            +
                    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         | 
| 139 | 
            +
                    return self.weight * hidden_states.to(input_dtype)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            class LayerScale(nn.Module):
         | 
| 143 | 
            +
                def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
         | 
| 144 | 
            +
                    super().__init__()
         | 
| 145 | 
            +
                    self.inplace = inplace
         | 
| 146 | 
            +
                    self.gamma = nn.Parameter(init_values * torch.ones(dim))
         | 
| 147 | 
            +
                    self.force_fp32 = force_fp32
         | 
| 148 | 
            +
                
         | 
| 149 | 
            +
                @torch.cuda.amp.autocast(enabled=False)
         | 
| 150 | 
            +
                def forward(self, x):
         | 
| 151 | 
            +
                    if self.force_fp32:
         | 
| 152 | 
            +
                        output_type = x.dtype
         | 
| 153 | 
            +
                        out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
         | 
| 154 | 
            +
                        return out.to(dtype=output_type)
         | 
| 155 | 
            +
                    else:
         | 
| 156 | 
            +
                        out = x.mul_(self.gamma) if self.inplace else x * self.gamma
         | 
| 157 | 
            +
                        return out
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            class Attention(nn.Module):
         | 
| 161 | 
            +
                def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
         | 
| 162 | 
            +
                             causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
         | 
| 163 | 
            +
                    super().__init__()
         | 
| 164 | 
            +
                    assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         | 
| 165 | 
            +
                    self.num_heads = num_heads
         | 
| 166 | 
            +
                    head_dim = dim // num_heads
         | 
| 167 | 
            +
                    self.scale = head_dim ** -0.5
         | 
| 168 | 
            +
                    
         | 
| 169 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 170 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 171 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 172 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 173 | 
            +
                    
         | 
| 174 | 
            +
                    self.use_flash_attn = use_flash_attn
         | 
| 175 | 
            +
                    if use_flash_attn:
         | 
| 176 | 
            +
                        self.causal = causal
         | 
| 177 | 
            +
                        self.inner_attn = FlashAttention(attention_dropout=attn_drop)
         | 
| 178 | 
            +
                    
         | 
| 179 | 
            +
                    self.qk_normalization = qk_normalization
         | 
| 180 | 
            +
                    self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
         | 
| 181 | 
            +
                    self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
         | 
| 182 | 
            +
                    self.use_fused_rmsnorm = use_fused_rmsnorm
         | 
| 183 | 
            +
                
         | 
| 184 | 
            +
                def _naive_attn(self, x):
         | 
| 185 | 
            +
                    B, N, C = x.shape
         | 
| 186 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         | 
| 187 | 
            +
                    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                    if self.qk_normalization:
         | 
| 190 | 
            +
                        B_, H_, N_, D_ = q.shape
         | 
| 191 | 
            +
                        q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
         | 
| 192 | 
            +
                        k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
         | 
| 193 | 
            +
                    
         | 
| 194 | 
            +
                    attn = ((q * self.scale) @ k.transpose(-2, -1))
         | 
| 195 | 
            +
                    # attn = attn - attn.max(-1)[0].unsqueeze(-1)  # in case of overflow for fp16
         | 
| 196 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 197 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 198 | 
            +
                    
         | 
| 199 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 200 | 
            +
                    x = self.proj(x)
         | 
| 201 | 
            +
                    x = self.proj_drop(x)
         | 
| 202 | 
            +
                    return x
         | 
| 203 | 
            +
                
         | 
| 204 | 
            +
                def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
         | 
| 205 | 
            +
                    
         | 
| 206 | 
            +
                    qkv = self.qkv(x)
         | 
| 207 | 
            +
                    qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
         | 
| 208 | 
            +
                    
         | 
| 209 | 
            +
                    if self.qk_normalization:
         | 
| 210 | 
            +
                        q, k, v = qkv.unbind(2)
         | 
| 211 | 
            +
                        if self.use_fused_rmsnorm:
         | 
| 212 | 
            +
                            q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
         | 
| 213 | 
            +
                            k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
         | 
| 214 | 
            +
                        else:
         | 
| 215 | 
            +
                            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
         | 
| 216 | 
            +
                            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
         | 
| 217 | 
            +
                        qkv = torch.stack([q, k, v], dim=2)
         | 
| 218 | 
            +
                    
         | 
| 219 | 
            +
                    context, _ = self.inner_attn(
         | 
| 220 | 
            +
                        qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
         | 
| 221 | 
            +
                    )
         | 
| 222 | 
            +
                    outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
         | 
| 223 | 
            +
                    outs = self.proj_drop(outs)
         | 
| 224 | 
            +
                    return outs
         | 
| 225 | 
            +
                
         | 
| 226 | 
            +
                def forward(self, x):
         | 
| 227 | 
            +
                    x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
         | 
| 228 | 
            +
                    return x
         | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            class Mlp(nn.Module):
         | 
| 232 | 
            +
                """ MLP as used in Vision Transformer, MLP-Mixer and related networks
         | 
| 233 | 
            +
                """
         | 
| 234 | 
            +
                
         | 
| 235 | 
            +
                def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
         | 
| 236 | 
            +
                             bias=True, drop=0.):
         | 
| 237 | 
            +
                    super().__init__()
         | 
| 238 | 
            +
                    out_features = out_features or in_features
         | 
| 239 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 240 | 
            +
                    bias = to_2tuple(bias)
         | 
| 241 | 
            +
                    drop_probs = to_2tuple(drop)
         | 
| 242 | 
            +
                    
         | 
| 243 | 
            +
                    self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
         | 
| 244 | 
            +
                    self.act = act_layer()
         | 
| 245 | 
            +
                    self.drop1 = nn.Dropout(drop_probs[0])
         | 
| 246 | 
            +
                    self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
         | 
| 247 | 
            +
                    self.drop2 = nn.Dropout(drop_probs[1])
         | 
| 248 | 
            +
                
         | 
| 249 | 
            +
                def forward(self, x):
         | 
| 250 | 
            +
                    x = self.fc1(x)
         | 
| 251 | 
            +
                    x = self.act(x)
         | 
| 252 | 
            +
                    x = self.drop1(x)
         | 
| 253 | 
            +
                    x = self.fc2(x)
         | 
| 254 | 
            +
                    x = self.drop2(x)
         | 
| 255 | 
            +
                    return x
         | 
| 256 | 
            +
             | 
| 257 | 
            +
             | 
| 258 | 
            +
            class Block(nn.Module):
         | 
| 259 | 
            +
                
         | 
| 260 | 
            +
                def __init__(
         | 
| 261 | 
            +
                        self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
         | 
| 262 | 
            +
                        drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
         | 
| 263 | 
            +
                        fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
         | 
| 264 | 
            +
                        use_fused_rmsnorm=False):
         | 
| 265 | 
            +
                    super().__init__()
         | 
| 266 | 
            +
                    
         | 
| 267 | 
            +
                    self.norm1 = norm_layer(dim)
         | 
| 268 | 
            +
                    self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
         | 
| 269 | 
            +
                                          use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
         | 
| 270 | 
            +
                                          qk_normalization=qk_normalization,
         | 
| 271 | 
            +
                                          use_fused_rmsnorm=use_fused_rmsnorm)
         | 
| 272 | 
            +
                    self.ls1 = LayerScale(dim, init_values=init_values,
         | 
| 273 | 
            +
                                          force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
         | 
| 274 | 
            +
                    # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         | 
| 275 | 
            +
                    self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 276 | 
            +
                    
         | 
| 277 | 
            +
                    self.norm2 = norm_layer(dim)
         | 
| 278 | 
            +
                    mlp_hidden_dim = int(dim * mlp_ratio)
         | 
| 279 | 
            +
                    if use_fused_mlp:
         | 
| 280 | 
            +
                        self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
         | 
| 281 | 
            +
                    else:
         | 
| 282 | 
            +
                        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
         | 
| 283 | 
            +
                    self.ls2 = LayerScale(dim, init_values=init_values,
         | 
| 284 | 
            +
                                          force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
         | 
| 285 | 
            +
                    self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 286 | 
            +
                    
         | 
| 287 | 
            +
                    self.with_cp = with_cp
         | 
| 288 | 
            +
                    self.use_fused_rmsnorm = use_fused_rmsnorm
         | 
| 289 | 
            +
                
         | 
| 290 | 
            +
                def forward(self, x, residual=None):
         | 
| 291 | 
            +
                    
         | 
| 292 | 
            +
                    def _inner_forward(x, residual=None):
         | 
| 293 | 
            +
                        if self.use_fused_rmsnorm:
         | 
| 294 | 
            +
                            x, residual = self.norm1(x, residual)
         | 
| 295 | 
            +
                            x = self.drop_path1(self.ls1(self.attn(x)))
         | 
| 296 | 
            +
                            x, residual = self.norm2(x, residual)
         | 
| 297 | 
            +
                            x = self.drop_path2(self.ls2(self.mlp(x)))
         | 
| 298 | 
            +
                            return x, residual
         | 
| 299 | 
            +
                        else:
         | 
| 300 | 
            +
                            assert residual is None
         | 
| 301 | 
            +
                            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
         | 
| 302 | 
            +
                            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         | 
| 303 | 
            +
                            return x
         | 
| 304 | 
            +
                    
         | 
| 305 | 
            +
                    if self.with_cp:
         | 
| 306 | 
            +
                        return checkpoint.checkpoint(_inner_forward, x, residual)
         | 
| 307 | 
            +
                    else:
         | 
| 308 | 
            +
                        return _inner_forward(x, residual=residual)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
             | 
| 311 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 312 | 
            +
                """ 3D Image to Patch Embedding
         | 
| 313 | 
            +
                """
         | 
| 314 | 
            +
                
         | 
| 315 | 
            +
                def __init__(
         | 
| 316 | 
            +
                        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, 
         | 
| 317 | 
            +
                        num_frames=8, tubelet_size=1, norm_layer=None
         | 
| 318 | 
            +
                    ):
         | 
| 319 | 
            +
                    super().__init__()
         | 
| 320 | 
            +
                    img_size = to_2tuple(img_size)
         | 
| 321 | 
            +
                    patch_size = to_2tuple(patch_size)
         | 
| 322 | 
            +
                    self.img_size = img_size
         | 
| 323 | 
            +
                    self.patch_size = patch_size
         | 
| 324 | 
            +
                    self.tubelet_size = tubelet_size
         | 
| 325 | 
            +
                    self.grid_size = (
         | 
| 326 | 
            +
                        num_frames // tubelet_size, 
         | 
| 327 | 
            +
                        img_size[0] // patch_size[0], 
         | 
| 328 | 
            +
                        img_size[1] // patch_size[1]
         | 
| 329 | 
            +
                    ) # (T, H, W)
         | 
| 330 | 
            +
                    self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
         | 
| 331 | 
            +
                    
         | 
| 332 | 
            +
                    self.proj = nn.Conv3d(
         | 
| 333 | 
            +
                        in_channels=in_chans, out_channels=embed_dim, 
         | 
| 334 | 
            +
                        kernel_size=(tubelet_size, patch_size[0], patch_size[1]), 
         | 
| 335 | 
            +
                        stride=(tubelet_size, patch_size[0], patch_size[1])
         | 
| 336 | 
            +
                    )
         | 
| 337 | 
            +
                    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
         | 
| 338 | 
            +
                
         | 
| 339 | 
            +
                def forward(self, x):
         | 
| 340 | 
            +
                    x = self.proj(x)
         | 
| 341 | 
            +
                    x = x.flatten(3).permute(0, 2, 3, 1)  # B x C x T x HW => B x T x HW x C
         | 
| 342 | 
            +
                    x = self.norm(x)
         | 
| 343 | 
            +
                    return x
         | 
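
A minimal shape walkthrough for the patch embedding above, assuming the class is importable; the sizes match the 224-pixel, patch-14 configuration used elsewhere in this commit.

    import torch

    pe = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=1408,
                    num_frames=8, tubelet_size=1)
    video = torch.randn(2, 3, 8, 224, 224)        # (B, C, T, H, W)
    tokens = pe(video)                            # Conv3d -> flatten(3) -> permute
    assert tokens.shape == (2, 8, 16 * 16, 1408)  # (B, T, H/14 * W/14, embed_dim)
    assert pe.num_patches == 8 * 16 * 16
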
| 344 | 
            +
             | 
| 345 | 
            +
             | 
| 346 | 
            +
            class InternVideo2(nn.Module):
         | 
| 347 | 
            +
                def __init__(
         | 
| 348 | 
            +
                        self,
         | 
| 349 | 
            +
                        in_chans: int = 3,
         | 
| 350 | 
            +
                        patch_size: int = 14,
         | 
| 351 | 
            +
                        img_size: int = 224,
         | 
| 352 | 
            +
                        qkv_bias: bool = False,
         | 
| 353 | 
            +
                        drop_path_rate: float = 0.25, # may need ablation
         | 
| 354 | 
            +
                        head_drop_path_rate: float = 0.,
         | 
| 355 | 
            +
                        embed_dim: int = 1408,
         | 
| 356 | 
            +
                        num_heads: int = 16,
         | 
| 357 | 
            +
                        mlp_ratio: float = 48/11,
         | 
| 358 | 
            +
                        init_values: float = 1e-5, # may need ablation
         | 
| 359 | 
            +
                        qk_normalization: bool = True,
         | 
| 360 | 
            +
                        depth: int = 40,
         | 
| 361 | 
            +
                        use_flash_attn: bool = True,
         | 
| 362 | 
            +
                        use_fused_rmsnorm: bool = True,
         | 
| 363 | 
            +
                        use_fused_mlp: bool = True,
         | 
| 364 | 
            +
                        fused_mlp_heuristic: int = 1,
         | 
| 365 | 
            +
                        attn_pool_num_heads: int = 16,
         | 
| 366 | 
            +
                        clip_embed_dim: int = 768,
         | 
| 367 | 
            +
                        layerscale_no_force_fp32: bool = False, # when True for training?
         | 
| 368 | 
            +
                        num_frames: int = 8,
         | 
| 369 | 
            +
                        tubelet_size: int = 1,
         | 
| 370 | 
            +
                        sep_pos_embed: bool = False,
         | 
| 371 | 
            +
                        use_checkpoint: bool = False,
         | 
| 372 | 
            +
                        checkpoint_num: int = 0,
         | 
| 373 | 
            +
                    ):
         | 
| 374 | 
            +
                    super().__init__()
         | 
| 375 | 
            +
                    
         | 
| 376 | 
            +
                    assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, logger.info(
         | 
| 377 | 
            +
                        'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
         | 
| 378 | 
            +
                    
         | 
| 379 | 
            +
                    self.use_flash_attn = use_flash_attn
         | 
| 380 | 
            +
                    self.embed_dim = embed_dim
         | 
| 381 | 
            +
                    self.T = num_frames // tubelet_size
         | 
| 382 | 
            +
                    
         | 
| 383 | 
            +
                    if use_fused_rmsnorm:
         | 
| 384 | 
            +
                        norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
         | 
| 385 | 
            +
                    else:
         | 
| 386 | 
            +
                        norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
         | 
| 387 | 
            +
                    self.norm_layer_for_blocks = norm_layer_for_blocks
         | 
| 388 | 
            +
                    self.patch_embed = PatchEmbed(
         | 
| 389 | 
            +
                        img_size, patch_size, in_chans, embed_dim,
         | 
| 390 | 
            +
                        num_frames=num_frames, tubelet_size=tubelet_size,
         | 
| 391 | 
            +
                    )
         | 
| 392 | 
            +
                    num_patches = self.patch_embed.num_patches
         | 
| 393 | 
            +
                    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         | 
| 394 | 
            +
                    
         | 
| 395 | 
            +
                    # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
         | 
| 396 | 
            +
                    self.sep_pos_embed = sep_pos_embed
         | 
| 397 | 
            +
                    if sep_pos_embed:
         | 
| 398 | 
            +
                        logger.info("Use seperable position embedding")
         | 
| 399 | 
            +
                        grid_size = self.patch_embed.grid_size
         | 
| 400 | 
            +
                        self.grid_size = grid_size
         | 
| 401 | 
            +
                        self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
         | 
| 402 | 
            +
                        self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
         | 
| 403 | 
            +
                        self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
         | 
| 404 | 
            +
                    else:
         | 
| 405 | 
            +
                        logger.info("Use joint position embedding")
         | 
| 406 | 
            +
                        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         | 
| 407 | 
            +
                    
         | 
| 408 | 
            +
                    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
         | 
| 409 | 
            +
                    # choose which layer to use checkpoint
         | 
| 410 | 
            +
                    with_cp_list = [False] * depth
         | 
| 411 | 
            +
                    if use_checkpoint:
         | 
| 412 | 
            +
                        for idx in range(depth):
         | 
| 413 | 
            +
                            if idx < checkpoint_num:
         | 
| 414 | 
            +
                                with_cp_list[idx] = True
         | 
| 415 | 
            +
                    logger.info(f"Droppath rate: {dpr}")
         | 
| 416 | 
            +
                    logger.info(f"Checkpoint list: {with_cp_list}")
         | 
| 417 | 
            +
                    
         | 
| 418 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 419 | 
            +
                        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
         | 
| 420 | 
            +
                              norm_layer=norm_layer_for_blocks,
         | 
| 421 | 
            +
                              drop_path=dpr[i], init_values=init_values, attn_drop=0.,
         | 
| 422 | 
            +
                              use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
         | 
| 423 | 
            +
                              fused_mlp_heuristic=fused_mlp_heuristic,
         | 
| 424 | 
            +
                              with_cp=with_cp_list[i],
         | 
| 425 | 
            +
                              qk_normalization=qk_normalization,
         | 
| 426 | 
            +
                              layerscale_no_force_fp32=layerscale_no_force_fp32,
         | 
| 427 | 
            +
                              use_fused_rmsnorm=use_fused_rmsnorm)
         | 
| 428 | 
            +
                        for i in range(depth)])
         | 
| 429 | 
            +
                    self.clip_projector = AttentionPoolingBlock(
         | 
| 430 | 
            +
                        dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
         | 
| 431 | 
            +
                        drop=0., attn_drop=0., drop_path=head_drop_path_rate, 
         | 
| 432 | 
            +
                        norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
         | 
| 433 | 
            +
                    )
         | 
| 434 | 
            +
                    
         | 
| 435 | 
            +
                    self.fc_norm = nn.Identity()
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    self.init_pos_embed()
         | 
| 438 | 
            +
                    trunc_normal_(self.cls_token, std=.02)
         | 
| 439 | 
            +
                    self.apply(self._init_weights)
         | 
| 440 | 
            +
                    self.fix_init_weight()
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                def init_pos_embed(self):
         | 
| 443 | 
            +
                    logger.info("Init pos_embed from sincos pos_embed")
         | 
| 444 | 
            +
                    if self.sep_pos_embed:
         | 
| 445 | 
            +
                        # trunc_normal_(self.pos_embed_spatial, std=.02)
         | 
| 446 | 
            +
                        # trunc_normal_(self.pos_embed_temporal, std=.02)
         | 
| 447 | 
            +
                        # trunc_normal_(self.pos_embed_cls, std=.02)
         | 
| 448 | 
            +
                        pos_embed_spatial = get_2d_sincos_pos_embed(
         | 
| 449 | 
            +
                            self.pos_embed_spatial.shape[-1], 
         | 
| 450 | 
            +
                            self.patch_embed.grid_size[1], # height & weight
         | 
| 451 | 
            +
                        )
         | 
| 452 | 
            +
                        self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
         | 
| 453 | 
            +
                        pos_embed_temporal = get_1d_sincos_pos_embed(
         | 
| 454 | 
            +
                            self.pos_embed_spatial.shape[-1], 
         | 
| 455 | 
            +
                            self.patch_embed.grid_size[0], # t_size
         | 
| 456 | 
            +
                        )
         | 
| 457 | 
            +
                        self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
         | 
| 458 | 
            +
                    else:
         | 
| 459 | 
            +
                        # trunc_normal_(self.pos_embed, std=.02)
         | 
| 460 | 
            +
                        pos_embed = get_3d_sincos_pos_embed(
         | 
| 461 | 
            +
                            self.pos_embed.shape[-1], 
         | 
| 462 | 
            +
                            self.patch_embed.grid_size[1], # height & weight
         | 
| 463 | 
            +
                            self.patch_embed.grid_size[0], # t_size
         | 
| 464 | 
            +
                            cls_token=True
         | 
| 465 | 
            +
                        )
         | 
| 466 | 
            +
                        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                def _init_weights(self, m):
         | 
| 469 | 
            +
                    if isinstance(m, nn.Linear):
         | 
| 470 | 
            +
                        trunc_normal_(m.weight, std=.02)
         | 
| 471 | 
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         | 
| 472 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 473 | 
            +
                    elif isinstance(m, nn.LayerNorm):
         | 
| 474 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 475 | 
            +
                        nn.init.constant_(m.weight, 1.0)
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                def fix_init_weight(self):
         | 
| 478 | 
            +
                    def rescale(param, layer_id):
         | 
| 479 | 
            +
                        param.div_(math.sqrt(2.0 * layer_id))
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    for layer_id, layer in enumerate(self.blocks):
         | 
| 482 | 
            +
                        rescale(layer.attn.proj.weight.data, layer_id + 1)
         | 
| 483 | 
            +
                        rescale(layer.mlp.fc2.weight.data, layer_id + 1)
         | 
| 484 | 
            +
                
         | 
| 485 | 
            +
                @property
         | 
| 486 | 
            +
                def dtype(self):
         | 
| 487 | 
            +
                    return self.patch_embed.proj.weight.dtype
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                def get_num_layers(self):
         | 
| 490 | 
            +
                    return len(self.blocks)
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                @torch.jit.ignore
         | 
| 493 | 
            +
                def no_weight_decay(self):
         | 
| 494 | 
            +
                    return {
         | 
| 495 | 
            +
                        'pos_embed', 
         | 
| 496 | 
            +
                        'pos_embed_spatial', 
         | 
| 497 | 
            +
                        'pos_embed_temporal', 
         | 
| 498 | 
            +
                        'pos_embed_cls',
         | 
| 499 | 
            +
                        'cls_token'
         | 
| 500 | 
            +
                    }
         | 
| 501 | 
            +
                
         | 
| 502 | 
            +
                def forward(self, x, use_image=False):
         | 
| 503 | 
            +
                    x = self.patch_embed(x.type(self.dtype))
         | 
| 504 | 
            +
                    B, T, L, C = x.shape  # T: temporal; L: spatial
         | 
| 505 | 
            +
                    x = x.view([B, T * L, C])
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    # append cls token
         | 
| 508 | 
            +
                    cls_tokens = self.cls_token.expand(B, -1, -1)
         | 
| 509 | 
            +
                    x = torch.cat((cls_tokens, x), dim=1)
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                    # add pos_embed
         | 
| 512 | 
            +
                    if self.sep_pos_embed:
         | 
| 513 | 
            +
                        if use_image:
         | 
| 514 | 
            +
                            pos_embed = self.pos_embed_spatial
         | 
| 515 | 
            +
                        else:
         | 
| 516 | 
            +
                            pos_embed = self.pos_embed_spatial.repeat(
         | 
| 517 | 
            +
                                1, self.grid_size[0], 1
         | 
| 518 | 
            +
                            ) + torch.repeat_interleave(
         | 
| 519 | 
            +
                                self.pos_embed_temporal,
         | 
| 520 | 
            +
                                self.grid_size[1] * self.grid_size[2],
         | 
| 521 | 
            +
                                dim=1,
         | 
| 522 | 
            +
                            )
         | 
| 523 | 
            +
                        pos_embed = torch.cat(
         | 
| 524 | 
            +
                            [
         | 
| 525 | 
            +
                                self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
         | 
| 526 | 
            +
                                pos_embed,
         | 
| 527 | 
            +
                            ],
         | 
| 528 | 
            +
                            1,
         | 
| 529 | 
            +
                        )
         | 
| 530 | 
            +
                    else:
         | 
| 531 | 
            +
                        if use_image:
         | 
| 532 | 
            +
                            cls_pos_embed = self.pos_embed[:, :1, :]
         | 
| 533 | 
            +
                            img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.T, L, C).mean(dim=1)
         | 
| 534 | 
            +
                            pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
         | 
| 535 | 
            +
                        else:
         | 
| 536 | 
            +
                            pos_embed = self.pos_embed
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                    x = x + pos_embed
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    residual = None
         | 
| 541 | 
            +
                    for blk in self.blocks:
         | 
| 542 | 
            +
                        if isinstance(x, tuple) and len(x) == 2:
         | 
| 543 | 
            +
                            x, residual = x
         | 
| 544 | 
            +
                        x = blk(x, residual=residual)
         | 
| 545 | 
            +
                    if isinstance(x, tuple) and len(x) == 2:
         | 
| 546 | 
            +
                        x, residual = x
         | 
| 547 | 
            +
                        if residual is not None:
         | 
| 548 | 
            +
                            x = x + residual
         | 
| 549 | 
            +
                    
         | 
| 550 | 
            +
                    x = self.clip_projector(x)
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    x = self.fc_norm(x)
         | 
| 553 | 
            +
                    return x
         | 
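
A minimal usage sketch of the video encoder defined above (illustrative only, not part of the uploaded files). It assumes the fused flash-attention/RMSNorm/MLP kernels are unavailable, so all three flags are disabled together as the constructor's assert requires, and it feeds a dummy clip in the [B, C, T, H, W] layout that PatchEmbed expects; the pooled output should have clip_embed_dim channels.

import torch

# Illustrative sketch, assuming the fused kernels are disabled and default sizes are kept.
model = InternVideo2(
    num_frames=8,
    use_flash_attn=False,     # the three fused-kernel flags must agree
    use_fused_rmsnorm=False,
    use_fused_mlp=False,
).eval()

video = torch.randn(1, 3, 8, 224, 224)  # [B, C, T, H, W]
with torch.no_grad():
    feat = model(video)                 # pooled clip-level feature
print(feat.shape)                       # expected: torch.Size([1, 768]), i.e. clip_embed_dim
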
    	
        mobile_clip.py
    ADDED
    
@@ -0,0 +1,264 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import math
from typing import Optional, Sequence

import torch
from torch import Tensor, nn

from typing import Dict
import open_clip

from .mobile_clip_transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)


class TextTransformer(nn.Module):
    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        super().__init__()

        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim

        # Token embedding layer
        self.embedding_layer = nn.Embedding(
            embedding_dim=model_dim, num_embeddings=self.vocab_size
        )
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        # Context length
        context_length = cfg["context_length"]
        assert (
            context_length is not None
        ), "Context length can't be None. Please set value accordingly."

        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(
                num_embeddings=context_length, embedding_dim=model_dim
            )
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)

        # Transformer layer
        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers for transformer layer
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            Warning(
                "{} expects FFN multipliers as a list, whose length is the same as"
                " number of transformer layers. Got: {}".format(
                    self.__class__.__name__, type(ffn_multipliers)
                )
            )
        elif (
            isinstance(ffn_multipliers, Sequence)
            and len(ffn_multipliers) != n_transformer_layers
        ):
            Warning(
                "We need FFN multiplier for each transformer layer. Got {} ffn"
                " multipliers while number of transformer layers = {}".format(
                    len(ffn_multipliers), n_transformer_layers
                )
            )
        ffn_dims = [
            int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
            for ffn_mult in ffn_multipliers
        ]

        # Heads for transformer layers
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            Warning(
                "{} expects MHA heads as a list, whose length is the same as number of "
                "transformer layers. Got: {}".format(
                    self.__class__.__name__, type(mha_heads)
                )
            )
        elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
            Warning(
                "{} needs MHA heads for each transformer layer. Got {} mha heads while"
                " number of transformer layers = {}".format(
                    self.__class__.__name__, len(mha_heads), n_transformer_layers
                )
            )

        if variant == "base":
            self.transformer = nn.ModuleList(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
        elif variant == "mct":
            raise NotImplementedError
        else:
            raise ValueError("Unrecognized text encoder variant {}".format(variant))

        self.final_layer_norm = get_normalization_layer(
            num_features=model_dim, norm_type=norm_layer
        )

        self.projection_layer = nn.Parameter(
            torch.empty(model_dim, self.projection_dim)
        )
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embedding for all tokens.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]

        Returns:
            A tensor of [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(
                token_emb.dtype
            )
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
        """Build causal attention mask [batch_size, context_length, context_length]."""
        # Build mask with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(0)  # add dummy batch dimension
        mask = mask.expand(batch_size, -1, -1)
        return mask

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: a tensor of boolean values as the padding mask.
                Shape: [batch_size, context_length]
            return_all_tokens: a boolean flag to return all tokens, defaults to False
                to return only EOT token embedding.
        Returns:
            A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
            True, otherwise a tensor of [batch_size, hidden_dim].
        """
        # Discrete tokens to continuous embeddings
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # [1, context_length, context_length]
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(
                context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
            )
            attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
            key_padding_mask = None

        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        # Apply layer norm
        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        token_emb = token_emb[
            torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
        ]

        token_emb = token_emb @ self.projection_layer
        return token_emb

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        # Image-text pair data with single caption
        # [B, CL] --> [B, d]
        text_tokens = self.encode_text(
            text_tokens=text_tokens,
            key_padding_mask=key_padding_mask,
            return_all_tokens=return_all_tokens,
            *args,
            **kwargs
        )
        return text_tokens


class ClipTokenizer(nn.Module):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.context_length = cfg["text_cfg"]["context_length"]
        model_name = getattr(cfg["text_cfg"], "open_clip_tokenizer", "ViT-B-16")
        self.tokenizer = open_clip.get_tokenizer(model_name)

    def get_vocab_size(self) -> int:
        return len(self.tokenizer.encoder)

    def get_encodings(self) -> Dict[str, int]:
        return self.tokenizer.encoder

    def get_eot_token(self) -> int:
        # Tokenizing an empty string returns a list [sot_id, eot_id]
        return self.tokenizer("")[1]

    def get_sot_token(self) -> int:
        # Tokenizing an empty string returns a list [sot_id, eot_id]
        return self.tokenizer("")[0]

    def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
        # tokenizer returns indices as a string
        tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
        assert (
            tokenized_sentence.shape[-1] == self.context_length
        ), "Tokenized tensor should be exactly `context_length` long."
        return tokenized_sentence
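
For orientation, a small hypothetical sketch of the tokenizer defined above (illustrative only, not part of the uploaded files); the minimal cfg shown is an assumption based on the keys the constructor reads. The EOT position that encode_text() later pools is just the argmax over token ids, since the EOT id is the largest id in each padded sequence.

import torch

# Hypothetical sketch: tokenize one caption and locate the EOT position used for pooling.
cfg = {"text_cfg": {"context_length": 77}}                 # assumed minimal config
tokenizer = ClipTokenizer(cfg)

tokens = tokenizer("a video of a dog catching a frisbee")  # [1, 77] LongTensor
eot_positions = tokens.argmax(dim=-1)                      # EOT id is the largest id per row
print(tokens.shape, eot_positions)
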
    	
        mobile_clip_transformer.py
    ADDED
    
@@ -0,0 +1,449 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
"""
Implementation of the following modules is borrowed from ml-cvnets repo:
https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py

Please see ACKNOWLEDGEMENTS for license details.
"""

from typing import List, Optional, Union

import torch
from torch import Size, Tensor, nn
from torch.nn import functional as F
from torchvision.ops import StochasticDepth


class LayerNormFP32(nn.LayerNorm):
    """
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor with FP32 precision
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            *args,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        # Convert input from dtype X to FP32 and perform normalization operation.
        # This may help with underflow/overflow issues that we typically see with normalization layers
        inp_dtype = x.dtype
        return super().forward(x.to(torch.float32)).to(inp_dtype)


def get_normalization_layer(norm_type, num_features):
    if norm_type == "layer_norm":
        return nn.LayerNorm(num_features)
    elif norm_type == "layer_norm_fp32":
        return LayerNormFP32(num_features)
    else:
        raise NotImplementedError(f"Option: {norm_type} not supported.")


class PositionalEmbedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        is_learnable: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        # Add other pos embedding here and logic to choose between them
        module = LearnablePositionalEmbedding

        self.pos_embed = module(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            interpolation_mode=interpolation_mode,
            *args,
            **kwargs,
        )

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        return self.pos_embed(seq_len, *args, **kwargs)

    def __repr__(self):
        return self.pos_embed.__repr__()


class LearnablePositionalEmbedding(nn.Module):
    """Learnable Positional embedding"""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.interpolation_mode = interpolation_mode

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.pos_embed[:, :, self.padding_idx, ...] = 0.0

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        # scale pos embedding
        pos_embed = self.pos_embed
        if self.padding_idx is not None:
            with torch.no_grad():
                pos_embed[:, :, self.padding_idx, ...] = 0.0

        if seq_len != self.num_embeddings:
            pos_embed = F.interpolate(
                pos_embed,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        # Input is of the form [Batch, Seq_len, Embedding_dim]
        return pos_embed.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format(
            self.__class__.__name__,
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
        )
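
A quick, hypothetical check of the embedding above (illustrative only, not part of the uploaded files): when the requested sequence length differs from num_embeddings, the learned table is resized with F.interpolate in the configured interpolation_mode rather than truncated.

import torch

# Hypothetical sketch: the table is trained for 77 positions but can be queried at other lengths.
pe = PositionalEmbedding(num_embeddings=77, embedding_dim=512)
print(pe(77).shape)  # torch.Size([1, 77, 512]) -- direct reshape of the learned table
print(pe(64).shape)  # torch.Size([1, 64, 512]) -- bilinear interpolation via F.interpolate
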
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            class MultiHeadAttention(nn.Module):
         | 
| 143 | 
            +
                """
         | 
| 144 | 
            +
                This layer applies a multi-head self- or cross-attention as described in
         | 
| 145 | 
            +
                `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                Args:
         | 
| 148 | 
            +
                    embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
         | 
| 149 | 
            +
                    num_heads (int): Number of heads in multi-head attention
         | 
| 150 | 
            +
                    attn_dropout (Optional[float]): Attention dropout. Default: 0.0
         | 
| 151 | 
            +
                    bias (Optional[bool]): Use bias or not. Default: ``True``
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                Shape:
         | 
| 154 | 
            +
                    - Input:
         | 
| 155 | 
            +
                       - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
         | 
| 156 | 
            +
                    and :math:`C_{in}` is input embedding dim
         | 
| 157 | 
            +
                       - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
         | 
| 158 | 
            +
                    - Output: same shape as the input
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                """
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def __init__(
         | 
| 163 | 
            +
                    self,
         | 
| 164 | 
            +
                    embed_dim: int,
         | 
| 165 | 
            +
                    num_heads: int,
         | 
| 166 | 
            +
                    attn_dropout: Optional[float] = 0.0,
         | 
| 167 | 
            +
                    bias: Optional[bool] = True,
         | 
| 168 | 
            +
                    output_dim: Optional[int] = None,
         | 
| 169 | 
            +
                    *args,
         | 
| 170 | 
            +
                    **kwargs,
         | 
| 171 | 
            +
                ) -> None:
         | 
| 172 | 
            +
                    if output_dim is None:
         | 
| 173 | 
            +
                        output_dim = embed_dim
         | 
| 174 | 
            +
                    super().__init__()
         | 
| 175 | 
            +
                    if embed_dim % num_heads != 0:
         | 
| 176 | 
            +
                        Warning(
         | 
| 177 | 
            +
                            "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
         | 
| 178 | 
            +
                                self.__class__.__name__, embed_dim, num_heads
         | 
| 179 | 
            +
                            )
         | 
| 180 | 
            +
                        )
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    self.qkv_proj = nn.Linear(
         | 
| 183 | 
            +
                        in_features=embed_dim, out_features=3 * embed_dim, bias=bias
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    self.attn_dropout = nn.Dropout(p=attn_dropout)
         | 
| 187 | 
            +
                    self.out_proj = nn.Linear(
         | 
| 188 | 
            +
                        in_features=embed_dim, out_features=output_dim, bias=bias
         | 
| 189 | 
            +
                    )
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    self.head_dim = embed_dim // num_heads
         | 
| 192 | 
            +
                    self.scaling = self.head_dim**-0.5
         | 
| 193 | 
            +
                    self.softmax = nn.Softmax(dim=-1)
         | 
| 194 | 
            +
                    self.num_heads = num_heads
         | 
| 195 | 
            +
                    self.embed_dim = embed_dim
         | 
| 196 | 
            +
                    self.use_separate_proj_weight = embed_dim != output_dim
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def __repr__(self):
         | 
| 199 | 
            +
                    return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
         | 
| 200 | 
            +
                        self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def _forward_impl(
         | 
| 204 | 
            +
                    self,
         | 
| 205 | 
            +
                    x_q: Tensor,
         | 
| 206 | 
            +
                    x_kv: Optional[Tensor] = None,
         | 
| 207 | 
            +
                    key_padding_mask: Optional[Tensor] = None,
         | 
| 208 | 
            +
                    attn_mask: Optional[Tensor] = None,
         | 
| 209 | 
            +
                ) -> Tensor:
         | 
| 210 | 
            +
                    # [N, S, C]
         | 
| 211 | 
            +
                    b_sz, S_len, in_channels = x_q.shape
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    if x_kv is None:
         | 
| 214 | 
            +
                        # self-attention
         | 
| 215 | 
            +
                        # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
         | 
| 216 | 
            +
                        qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
         | 
| 217 | 
            +
                    # [N, S, 3, h, c] --> [N, h, 3, S, c]
         | 
| 218 | 
            +
                        qkv = qkv.transpose(1, 3).contiguous()
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    # [N, h, 3, S, c] --> [N, h, S, c] x 3
         | 
| 221 | 
            +
                        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
         | 
| 222 | 
            +
                    else:
         | 
| 223 | 
            +
                        T_len = x_kv.shape[1]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                        # cross-attention
         | 
| 226 | 
            +
                        # [N, S, C]
         | 
| 227 | 
            +
                        query = F.linear(
         | 
| 228 | 
            +
                            x_q,
         | 
| 229 | 
            +
                            weight=self.qkv_proj.weight[: self.embed_dim, ...],
         | 
| 230 | 
            +
                            bias=self.qkv_proj.bias[: self.embed_dim]
         | 
| 231 | 
            +
                            if self.qkv_proj.bias is not None
         | 
| 232 | 
            +
                            else None,
         | 
| 233 | 
            +
                        )
         | 
| 234 | 
            +
                        # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
         | 
| 235 | 
            +
                        query = (
         | 
| 236 | 
            +
                            query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
         | 
| 237 | 
            +
                            .transpose(1, 2)
         | 
| 238 | 
            +
                            .contiguous()
         | 
| 239 | 
            +
                        )
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                        # [N, T, C] --> [N, T, 2C]
         | 
| 242 | 
            +
                        kv = F.linear(
         | 
| 243 | 
            +
                            x_kv,
         | 
| 244 | 
            +
                            weight=self.qkv_proj.weight[self.embed_dim :, ...],
         | 
| 245 | 
            +
                            bias=self.qkv_proj.bias[self.embed_dim :]
         | 
| 246 | 
            +
                            if self.qkv_proj.bias is not None
         | 
| 247 | 
            +
                            else None,
         | 
| 248 | 
            +
                        )
         | 
| 249 | 
            +
                        # [N, T, 2C] --> [N, T, 2, h, c]
         | 
| 250 | 
            +
                        kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
         | 
| 251 | 
            +
                        # [N, T, 2, h, c] --> [N, h, 2, T, c]
         | 
| 252 | 
            +
                        kv = kv.transpose(1, 3).contiguous()
         | 
| 253 | 
            +
                        key, value = kv[:, :, 0], kv[:, :, 1]
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    query = query * self.scaling
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                # [N, h, T, c] --> [N, h, c, T]
         | 
| 258 | 
            +
                    key = key.transpose(-1, -2)
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    # QK^T
         | 
| 261 | 
            +
                    # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
         | 
| 262 | 
            +
                    attn = torch.matmul(query, key)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
         | 
| 265 | 
            +
                    if attn_mask is not None:
         | 
| 266 | 
            +
                        # attn_mask shape should be the same as attn
         | 
| 267 | 
            +
                        assert list(attn_mask.shape) == [
         | 
| 268 | 
            +
                            batch_size,
         | 
| 269 | 
            +
                            num_src_tokens,
         | 
| 270 | 
            +
                            num_tgt_tokens,
         | 
| 271 | 
            +
                        ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
         | 
| 272 | 
            +
                            batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
         | 
| 273 | 
            +
                        )
         | 
| 274 | 
            +
                        # [N, S, T] --> [N, 1, S, T]
         | 
| 275 | 
            +
                        attn_mask = attn_mask.unsqueeze(1)
         | 
| 276 | 
            +
                        attn = attn + attn_mask
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    if key_padding_mask is not None:
         | 
| 279 | 
            +
                        # Do not attend to padding positions
         | 
| 280 | 
            +
                        # key padding mask size is [N, T]
         | 
| 281 | 
            +
                        assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
         | 
| 282 | 
            +
                            batch_size,
         | 
| 283 | 
            +
                            num_tgt_tokens,
         | 
| 284 | 
            +
                        ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
         | 
| 285 | 
            +
                            batch_size, num_tgt_tokens, key_padding_mask.shape
         | 
| 286 | 
            +
                        )
         | 
| 287 | 
            +
                        attn = attn.masked_fill(
         | 
| 288 | 
            +
                            key_padding_mask.unsqueeze(1)
         | 
| 289 | 
            +
                            .unsqueeze(2)
         | 
| 290 | 
            +
                            .to(torch.bool),  # [N, T] --> [N, 1, 1, T]
         | 
| 291 | 
            +
                            float("-inf"),
         | 
| 292 | 
            +
                        )
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    attn_dtype = attn.dtype
         | 
| 295 | 
            +
                    attn_as_float = self.softmax(attn.float())
         | 
| 296 | 
            +
                    attn = attn_as_float.to(attn_dtype)
         | 
| 297 | 
            +
                    attn = self.attn_dropout(attn)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    # weighted sum
         | 
| 300 | 
            +
                    # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
         | 
| 301 | 
            +
                    out = torch.matmul(attn, value)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
         | 
| 304 | 
            +
                    out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
         | 
| 305 | 
            +
                    out = self.out_proj(out)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    return out
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                def forward(
         | 
| 310 | 
            +
                    self,
         | 
| 311 | 
            +
                    x_q: Tensor,
         | 
| 312 | 
            +
                    x_kv: Optional[Tensor] = None,
         | 
| 313 | 
            +
                    key_padding_mask: Optional[Tensor] = None,
         | 
| 314 | 
            +
                    attn_mask: Optional[Tensor] = None,
         | 
| 315 | 
            +
                    *args,
         | 
| 316 | 
            +
                    **kwargs,
         | 
| 317 | 
            +
                ) -> Tensor:
         | 
| 318 | 
            +
        # [Batch, Sequence, Hidden_dim]
         | 
| 319 | 
            +
                    return self._forward_impl(
         | 
| 320 | 
            +
                        x_q=x_q,
         | 
| 321 | 
            +
                        x_kv=x_kv,
         | 
| 322 | 
            +
                        key_padding_mask=key_padding_mask,
         | 
| 323 | 
            +
                        attn_mask=attn_mask,
         | 
| 324 | 
            +
                    )
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            class TransformerEncoder(nn.Module):
         | 
| 328 | 
            +
                """
         | 
| 329 | 
            +
                This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
         | 
| 330 | 
            +
                Args:
         | 
| 331 | 
            +
                    embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
         | 
| 332 | 
            +
                    ffn_latent_dim: Inner dimension of the FFN.
         | 
| 333 | 
            +
                    num_heads: Number of heads in multi-head attention. Default: 8.
         | 
| 334 | 
            +
                    attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
         | 
| 335 | 
            +
                    dropout: Dropout rate. Default: 0.0.
         | 
| 336 | 
            +
                    ffn_dropout: Dropout between FFN layers. Default: 0.0.
         | 
| 337 | 
            +
                    transformer_norm_layer: Normalization layer. Default: layer_norm.
         | 
| 338 | 
            +
        stochastic_dropout: Stochastic depth (drop path) rate. Default: 0.0.
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                Shape:
         | 
| 341 | 
            +
                    - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
         | 
| 342 | 
            +
                    and :math:`C_{in}` is input embedding dim
         | 
| 343 | 
            +
                    - Output: same shape as the input
         | 
| 344 | 
            +
                """
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                def __init__(
         | 
| 347 | 
            +
                    self,
         | 
| 348 | 
            +
                    embed_dim: int,
         | 
| 349 | 
            +
                    ffn_latent_dim: int,
         | 
| 350 | 
            +
                    num_heads: Optional[int] = 8,
         | 
| 351 | 
            +
                    attn_dropout: Optional[float] = 0.0,
         | 
| 352 | 
            +
                    dropout: Optional[float] = 0.0,
         | 
| 353 | 
            +
                    ffn_dropout: Optional[float] = 0.0,
         | 
| 354 | 
            +
                    transformer_norm_layer: Optional[str] = "layer_norm",
         | 
| 355 | 
            +
                    stochastic_dropout: Optional[float] = 0.0,
         | 
| 356 | 
            +
                    *args,
         | 
| 357 | 
            +
                    **kwargs,
         | 
| 358 | 
            +
                ) -> None:
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    super().__init__()
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    # Build attention layer
         | 
| 363 | 
            +
                    attn_unit = MultiHeadAttention(
         | 
| 364 | 
            +
                        embed_dim,
         | 
| 365 | 
            +
                        num_heads,
         | 
| 366 | 
            +
                        attn_dropout=attn_dropout,
         | 
| 367 | 
            +
                        bias=True,
         | 
| 368 | 
            +
                    )
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                    self.pre_norm_mha = nn.Sequential(
         | 
| 371 | 
            +
                        get_normalization_layer(
         | 
| 372 | 
            +
                            norm_type=transformer_norm_layer, num_features=embed_dim
         | 
| 373 | 
            +
                        ),
         | 
| 374 | 
            +
                        attn_unit,
         | 
| 375 | 
            +
                        nn.Dropout(p=dropout),
         | 
| 376 | 
            +
                    )
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    act_name = nn.GELU()
         | 
| 379 | 
            +
                    self.pre_norm_ffn = nn.Sequential(
         | 
| 380 | 
            +
                        get_normalization_layer(
         | 
| 381 | 
            +
                            norm_type=transformer_norm_layer, num_features=embed_dim
         | 
| 382 | 
            +
                        ),
         | 
| 383 | 
            +
                        nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
         | 
| 384 | 
            +
                        act_name,
         | 
| 385 | 
            +
                        nn.Dropout(p=ffn_dropout),
         | 
| 386 | 
            +
                        nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
         | 
| 387 | 
            +
                        nn.Dropout(p=dropout),
         | 
| 388 | 
            +
                    )
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    self.drop_path = nn.Identity()
         | 
| 391 | 
            +
                    if stochastic_dropout > 0.0:
         | 
| 392 | 
            +
                        if dropout > 0.0:
         | 
| 393 | 
            +
                    import warnings  # local import; module-level imports are outside this hunk
                    warnings.warn(
         | 
| 394 | 
            +
                                "Stochastic dropout and dropout are mutually exclusive. "
         | 
| 395 | 
            +
                                "Use either of them, but not both."
         | 
| 396 | 
            +
                                "Got: {} and {}".format(stochastic_dropout, dropout)
         | 
| 397 | 
            +
                            )
         | 
| 398 | 
            +
                        self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    self.embed_dim = embed_dim
         | 
| 401 | 
            +
                    self.ffn_dim = ffn_latent_dim
         | 
| 402 | 
            +
                    self.ffn_dropout = ffn_dropout
         | 
| 403 | 
            +
                    self.stochastic_dropout = stochastic_dropout
         | 
| 404 | 
            +
                    self.std_dropout = dropout
         | 
| 405 | 
            +
                    self.attn_fn_name = attn_unit.__class__.__name__
         | 
| 406 | 
            +
                    self.act_fn_name = act_name.__class__.__name__
         | 
| 407 | 
            +
                    self.norm_type = transformer_norm_layer
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                def __repr__(self) -> str:
         | 
| 410 | 
            +
                    return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
         | 
| 411 | 
            +
                        self.__class__.__name__,
         | 
| 412 | 
            +
                        self.embed_dim,
         | 
| 413 | 
            +
                        self.ffn_dim,
         | 
| 414 | 
            +
                        self.std_dropout,
         | 
| 415 | 
            +
                        self.ffn_dropout,
         | 
| 416 | 
            +
                        self.stochastic_dropout,
         | 
| 417 | 
            +
                        self.attn_fn_name,
         | 
| 418 | 
            +
                        self.act_fn_name,
         | 
| 419 | 
            +
                        self.norm_type,
         | 
| 420 | 
            +
                    )
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                def forward(
         | 
| 423 | 
            +
                    self,
         | 
| 424 | 
            +
                    x: Tensor,
         | 
| 425 | 
            +
                    x_prev: Optional[Tensor] = None,
         | 
| 426 | 
            +
                    key_padding_mask: Optional[Tensor] = None,
         | 
| 427 | 
            +
                    attn_mask: Optional[Tensor] = None,
         | 
| 428 | 
            +
                    *args,
         | 
| 429 | 
            +
                    **kwargs,
         | 
| 430 | 
            +
                ) -> Tensor:
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    # Multi-head attention
         | 
| 433 | 
            +
                    res = x
         | 
| 434 | 
            +
                    x = self.pre_norm_mha[0](x)  # norm
         | 
| 435 | 
            +
                    x = self.pre_norm_mha[1](
         | 
| 436 | 
            +
                        x_q=x,
         | 
| 437 | 
            +
                        x_kv=x_prev,
         | 
| 438 | 
            +
                        key_padding_mask=key_padding_mask,
         | 
| 439 | 
            +
                        attn_mask=attn_mask,
         | 
| 440 | 
            +
                        *args,
         | 
| 441 | 
            +
                        **kwargs,
         | 
| 442 | 
            +
                    )  # mha
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    x = self.drop_path(self.pre_norm_mha[2](x))  # applying stochastic depth
         | 
| 445 | 
            +
                    x = x + res
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    # Feed forward network
         | 
| 448 | 
            +
                    x = x + self.drop_path(self.pre_norm_ffn(x))
         | 
| 449 | 
            +
                    return x
         | 
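
A minimal shape-check sketch for the two modules above. The import path is an assumption (running next to this mobile_clip_transformer.py inside the repository), and the tensor sizes are illustrative only:

    import torch
    # assumed import path: this file inside the repository
    from mobile_clip_transformer import MultiHeadAttention, TransformerEncoder

    N, S, T, C = 2, 16, 49, 256          # batch, source tokens, target tokens, embed dim

    mha = MultiHeadAttention(embed_dim=C, num_heads=8, attn_dropout=0.0, bias=True)
    x_q = torch.randn(N, S, C)
    x_kv = torch.randn(N, T, C)
    self_out = mha(x_q)                  # self-attention: [N, S, C]
    cross_out = mha(x_q, x_kv)           # cross-attention over x_kv: [N, S, C]

    block = TransformerEncoder(embed_dim=C, ffn_latent_dim=4 * C, num_heads=8)
    y = block(x_q)                       # pre-norm MHA + FFN with residual connections: [N, S, C]
    print(self_out.shape, cross_out.shape, y.shape)

Both paths go through the same fused qkv projection: the cross-attention branch simply slices its weight and bias into query and key/value parts.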
    	
        model.safetensors
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:136b42a078a8ec440e38b56d91d570fb0969643a641795e06e171162ab176b4e
         | 
| 3 | 
            +
            size 745562274
         | 
    	
        modeling_internvideo2encoder.py
    ADDED
    
    | @@ -0,0 +1,152 @@ | |
| 1 | 
            +
            # from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
         | 
| 2 | 
            +
            from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
         | 
| 3 | 
            +
            from .config import InternVideo2Config as config
         | 
| 4 | 
            +
            import warnings
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from torch import nn
         | 
| 7 | 
            +
            import torchvision.transforms as transforms
         | 
| 8 | 
            +
            from torchvision.transforms import InterpolationMode
         | 
| 9 | 
            +
            from transformers.utils import logging
         | 
| 10 | 
            +
            warnings.filterwarnings("ignore")
         | 
| 11 | 
            +
            from .internvideo2_clip_vision import InternVideo2
         | 
| 12 | 
            +
            from .mobile_clip import TextTransformer, ClipTokenizer
         | 
| 13 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            class InternVideo2_CLIP_small(PreTrainedModel):
         | 
| 16 | 
            +
                config_class = config
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def __init__(self, config,  tokenizer=None, is_pretrain=True):
         | 
| 19 | 
            +
                    super().__init__(config)
         | 
| 20 | 
            +
                    self.config = config
         | 
| 21 | 
            +
                    self.tokenizer = tokenizer
         | 
| 22 | 
            +
                    self.is_pretrain = is_pretrain
         | 
| 23 | 
            +
        logger.info(config)
         | 
| 24 | 
            +
                    if tokenizer is None:
         | 
| 25 | 
            +
                        self.tokenizer = ClipTokenizer(self.config.model.text_encoder)
         | 
| 26 | 
            +
                    # self.model = IV2S2(self.config).to('cpu').to(torch.float16)
         | 
| 27 | 
            +
                    self.vision_encoder = self.build_vision_encoder()
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    self.vision_align = nn.Sequential(
         | 
| 30 | 
            +
                        nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
         | 
| 31 | 
            +
                        nn.Linear(
         | 
| 32 | 
            +
                            self.config.model.vision_encoder.clip_embed_dim, 
         | 
| 33 | 
            +
                            self.config.model.vision_encoder.align_dim
         | 
| 34 | 
            +
                        ),
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
                    self.text_encoder = self.build_text_encoder(cfg=self.config.model.text_encoder['text_cfg'], projection_dim=self.config.model.text_encoder["embed_dim"])
         | 
| 37 | 
            +
                    # adopt 1 / 100. as in ViCLIP
         | 
| 38 | 
            +
                    self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
         | 
| 39 | 
            +
                    self.temp_min = config.model.temp_min
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    if self.config.model.freeze_vision:
         | 
| 42 | 
            +
                        for name, p in self.vision_encoder.named_parameters():
         | 
| 43 | 
            +
                            if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
         | 
| 44 | 
            +
                                logger.info(f"Unfreeze {name}")
         | 
| 45 | 
            +
                            else:
         | 
| 46 | 
            +
                                logger.info(f"Freeze {name}")
         | 
| 47 | 
            +
                                p.requires_grad = False
         | 
| 48 | 
            +
                    if self.config.model.freeze_text:
         | 
| 49 | 
            +
                        for name, p in self.text_encoder.named_parameters():
         | 
| 50 | 
            +
                            if self.config.model.open_text_projection and name.startswith('projection_layer'):
         | 
| 51 | 
            +
                                logger.info(f"Unfreeze {name}")
         | 
| 52 | 
            +
                            else:
         | 
| 53 | 
            +
                                logger.info(f"Freeze {name}")
         | 
| 54 | 
            +
                                p.requires_grad = False
         | 
| 55 | 
            +
                    img_size = self.config.model.vision_encoder.img_size
         | 
| 56 | 
            +
                    self.transform = transforms.Compose(
         | 
| 57 | 
            +
                        [
         | 
| 58 | 
            +
                            transforms.Resize(
         | 
| 59 | 
            +
                                (img_size, img_size),
         | 
| 60 | 
            +
                                interpolation=InterpolationMode.BICUBIC,
         | 
| 61 | 
            +
                            ),
         | 
| 62 | 
            +
                            transforms.Lambda(lambda x: x.float().div(255.0)),
         | 
| 63 | 
            +
                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
         | 
| 64 | 
            +
                        ]
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
                @torch.no_grad()
         | 
| 69 | 
            +
                def clip_contrastive_temperature(self):
         | 
| 70 | 
            +
                    """Seems only used during pre-training"""
         | 
| 71 | 
            +
                    self.temp.clamp_(min=self.temp_min)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def encode_vision(self, image, test=False):
         | 
| 74 | 
            +
                    """encode image / videos as features.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    Args:
         | 
| 77 | 
            +
                        image (torch.Tensor): The input images.
         | 
| 78 | 
            +
                        test (bool): Whether testing.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
        Returns: torch.Tensor.
         | 
| 81 | 
            +
            - vision_embeds (torch.Tensor): The aligned visual embedding. Shape: [B,C].
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    """
         | 
| 84 | 
            +
                    T = image.shape[1]
         | 
| 85 | 
            +
        use_image = (T == 1)
         | 
| 86 | 
            +
                    image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    vision_embeds = self.vision_encoder(image, use_image=use_image)
         | 
| 89 | 
            +
                    vision_embeds = self.vision_align(vision_embeds)
         | 
| 90 | 
            +
                    return vision_embeds
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def encode_text(self, text):
         | 
| 93 | 
            +
                    """encode text.
         | 
| 94 | 
            +
                    Args:
         | 
| 95 | 
            +
            text (dict): The output of huggingface's `PreTrainedTokenizer`. Contains keys:
         | 
| 96 | 
            +
                            - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
         | 
| 97 | 
            +
                - attention_mask (torch.Tensor): The mask indicating padded tokens. Shape: [B,L]. 0 marks a padded token.
         | 
| 98 | 
            +
                            - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
         | 
| 99 | 
            +
        Returns: torch.Tensor.
         | 
| 100 | 
            +
            - text_embeds (torch.Tensor): The projected text embedding. Shape: [B,C].
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    """
         | 
| 103 | 
            +
                    text_embeds = self.text_encoder(text)
         | 
| 104 | 
            +
                    return text_embeds
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def build_vision_encoder(self):
         | 
| 107 | 
            +
                    """build vision encoder
         | 
| 108 | 
            +
        Returns: nn.Module. The vision encoder.
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    """
         | 
| 111 | 
            +
                    vision_encoder = InternVideo2(
         | 
| 112 | 
            +
                        in_chans=self.config.model.vision_encoder.in_chans,
         | 
| 113 | 
            +
                        patch_size=self.config.model.vision_encoder.patch_size,
         | 
| 114 | 
            +
                        img_size=self.config.model.vision_encoder.img_size,
         | 
| 115 | 
            +
                        qkv_bias=self.config.model.vision_encoder.qkv_bias,
         | 
| 116 | 
            +
                        drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
         | 
| 117 | 
            +
                        head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
         | 
| 118 | 
            +
                        embed_dim=self.config.model.vision_encoder.embed_dim,
         | 
| 119 | 
            +
                        num_heads=self.config.model.vision_encoder.num_heads,
         | 
| 120 | 
            +
                        mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
         | 
| 121 | 
            +
                        init_values=self.config.model.vision_encoder.init_values,
         | 
| 122 | 
            +
                        qk_normalization=self.config.model.vision_encoder.qk_normalization,
         | 
| 123 | 
            +
                        depth=self.config.model.vision_encoder.depth,
         | 
| 124 | 
            +
                        use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
         | 
| 125 | 
            +
                        use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
         | 
| 126 | 
            +
                        use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
         | 
| 127 | 
            +
                        fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
         | 
| 128 | 
            +
                        attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
         | 
| 129 | 
            +
                        clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
         | 
| 130 | 
            +
                        layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
         | 
| 131 | 
            +
                        num_frames=self.config.model.vision_encoder.num_frames,
         | 
| 132 | 
            +
                        tubelet_size=self.config.model.vision_encoder.tubelet_size,
         | 
| 133 | 
            +
                        sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
         | 
| 134 | 
            +
                        use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
         | 
| 135 | 
            +
                        checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
         | 
| 136 | 
            +
                    )
         | 
| 137 | 
            +
                    return vision_encoder
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                def build_text_encoder(self, cfg, projection_dim):
         | 
| 140 | 
            +
                    """build text_encoder and possiblly video-to-text multimodal fusion encoder.
         | 
| 141 | 
            +
                    Returns: nn.Module. The text encoder
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    """
         | 
| 144 | 
            +
                    text_encoder = TextTransformer(cfg, projection_dim)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    return text_encoder
         | 
| 147 | 
            +
                
         | 
| 148 | 
            +
            if __name__ == "__main__":
         | 
| 149 | 
            +
                model_config = config()
         | 
| 150 | 
            +
    model = InternVideo2_CLIP_small(model_config)
         | 
| 151 | 
            +
    x = torch.randn(2, 8, 3, 224, 224).to(model.device)  # [B, T, C, H, W]
         | 
| 152 | 
            +
    output = model.encode_vision(x)  # smoke test; the class defines encode_vision / encode_text rather than forward
         | 
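
A hedged end-to-end sketch of using the encoder defined above. The repository id is a placeholder, trust_remote_code assumes the usual auto_map entries in config.json, and the frame count / resolution follow the defaults used elsewhere in this upload (8 frames at 224x224); the vision path below is taken directly from encode_vision, while the text-side call signature is not shown in this hunk:

    import torch
    from transformers import AutoModel

    repo_id = "<this-repo-id>"  # placeholder, not a real identifier
    # trust_remote_code is needed because the modeling/config code lives in this repo;
    # assumes the usual auto_map entries in config.json.
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

    # encode_vision expects [B, T, C, H, W] and permutes to [B, C, T, H, W] internally.
    frames = torch.randint(0, 256, (8, 3, 256, 256)).float()   # dummy 8-frame clip, [T, C, H, W], raw pixel range
    frames = model.transform(frames)                            # resize to img_size, scale to [0, 1], normalize
    video = frames.unsqueeze(0)                                 # [B, T, C, H, W]

    with torch.no_grad():
        vision_embeds = model.encode_vision(video)              # [B, align_dim]

    # Text side: model.tokenizer is a ClipTokenizer and model.encode_text feeds its
    # output straight into the TextTransformer; the exact tokenizer call signature
    # is defined in mobile_clip.py and is not shown in this hunk.

For retrieval-style scoring one would typically L2-normalize both embeddings and scale their dot product by the learned temperature stored in model.temp.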
    	
        pos_embed.py
    ADDED
    
    | @@ -0,0 +1,299 @@ | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # --------------------------------------------------------
         | 
| 8 | 
            +
            # 3D sine-cosine position embedding
         | 
| 9 | 
            +
            # References:
         | 
| 10 | 
            +
            # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
         | 
| 11 | 
            +
            # --------------------------------------------------------
         | 
| 12 | 
            +
            def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                grid_size: int of the grid height and width
         | 
| 15 | 
            +
                t_size: int of the temporal size
         | 
| 16 | 
            +
                return:
         | 
| 17 | 
            +
                pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                assert embed_dim % 4 == 0
         | 
| 20 | 
            +
                embed_dim_spatial = embed_dim // 4 * 3
         | 
| 21 | 
            +
                embed_dim_temporal = embed_dim // 4
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                # spatial
         | 
| 24 | 
            +
                grid_h = np.arange(grid_size, dtype=np.float32)
         | 
| 25 | 
            +
                grid_w = np.arange(grid_size, dtype=np.float32)
         | 
| 26 | 
            +
                grid = np.meshgrid(grid_w, grid_h)  # here w goes first
         | 
| 27 | 
            +
                grid = np.stack(grid, axis=0)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                grid = grid.reshape([2, 1, grid_size, grid_size])
         | 
| 30 | 
            +
                pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
         | 
| 31 | 
            +
                    embed_dim_spatial, grid
         | 
| 32 | 
            +
                )
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                # temporal
         | 
| 35 | 
            +
                grid_t = np.arange(t_size, dtype=np.float32)
         | 
| 36 | 
            +
                pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
         | 
| 37 | 
            +
                    embed_dim_temporal, grid_t
         | 
| 38 | 
            +
                )
         | 
| 39 | 
            +
             | 
| 40 | 
            +
    # concatenate: [T, H, W] order
         | 
| 41 | 
            +
                pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
         | 
| 42 | 
            +
                pos_embed_temporal = np.repeat(
         | 
| 43 | 
            +
                    pos_embed_temporal, grid_size**2, axis=1
         | 
| 44 | 
            +
                )  # [T, H*W, D // 4]
         | 
| 45 | 
            +
                pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
         | 
| 46 | 
            +
                pos_embed_spatial = np.repeat(
         | 
| 47 | 
            +
                    pos_embed_spatial, t_size, axis=0
         | 
| 48 | 
            +
                )  # [T, H*W, D // 4 * 3]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
         | 
| 51 | 
            +
                pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                if cls_token:
         | 
| 54 | 
            +
                    pos_embed = np.concatenate(
         | 
| 55 | 
            +
                        [np.zeros([1, embed_dim]), pos_embed], axis=0
         | 
| 56 | 
            +
                    )
         | 
| 57 | 
            +
                return pos_embed
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            # --------------------------------------------------------
         | 
| 61 | 
            +
            # 2D sine-cosine position embedding
         | 
| 62 | 
            +
            # References:
         | 
| 63 | 
            +
            # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
         | 
| 64 | 
            +
            # MoCo v3: https://github.com/facebookresearch/moco-v3
         | 
| 65 | 
            +
            # --------------------------------------------------------
         | 
| 66 | 
            +
            def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
         | 
| 67 | 
            +
                """
         | 
| 68 | 
            +
                grid_size: int of the grid height and width
         | 
| 69 | 
            +
                return:
         | 
| 70 | 
            +
                pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                grid_h = np.arange(grid_size, dtype=np.float32)
         | 
| 73 | 
            +
                grid_w = np.arange(grid_size, dtype=np.float32)
         | 
| 74 | 
            +
                grid = np.meshgrid(grid_w, grid_h)  # here w goes first
         | 
| 75 | 
            +
                grid = np.stack(grid, axis=0)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                grid = grid.reshape([2, 1, grid_size, grid_size])
         | 
| 78 | 
            +
                pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
         | 
| 79 | 
            +
                if cls_token:
         | 
| 80 | 
            +
                    pos_embed = np.concatenate(
         | 
| 81 | 
            +
                        [np.zeros([1, embed_dim]), pos_embed], axis=0
         | 
| 82 | 
            +
                    )
         | 
| 83 | 
            +
                return pos_embed
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
         | 
| 87 | 
            +
                """
         | 
| 88 | 
            +
                t_size: int of the temporal size
         | 
| 89 | 
            +
                return:
         | 
| 90 | 
            +
                pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                grid_t = np.arange(t_size, dtype=np.float32)
         | 
| 93 | 
            +
                pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
         | 
| 94 | 
            +
                if cls_token:
         | 
| 95 | 
            +
                    pos_embed = np.concatenate(
         | 
| 96 | 
            +
                        [np.zeros([1, embed_dim]), pos_embed], axis=0
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
                return pos_embed
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
         | 
| 102 | 
            +
                assert embed_dim % 2 == 0
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                # use half of dimensions to encode grid_h
         | 
| 105 | 
            +
                emb_h = get_1d_sincos_pos_embed_from_grid(
         | 
| 106 | 
            +
                    embed_dim // 2, grid[0]
         | 
| 107 | 
            +
                )  # (H*W, D/2)
         | 
| 108 | 
            +
                emb_w = get_1d_sincos_pos_embed_from_grid(
         | 
| 109 | 
            +
                    embed_dim // 2, grid[1]
         | 
| 110 | 
            +
                )  # (H*W, D/2)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
         | 
| 113 | 
            +
                return emb
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
         | 
| 117 | 
            +
                """
         | 
| 118 | 
            +
                embed_dim: output dimension for each position
         | 
| 119 | 
            +
                pos: a list of positions to be encoded: size (M,)
         | 
| 120 | 
            +
                out: (M, D)
         | 
| 121 | 
            +
                """
         | 
| 122 | 
            +
                assert embed_dim % 2 == 0
         | 
| 123 | 
            +
                omega = np.arange(embed_dim // 2, dtype=np.float32)
         | 
| 124 | 
            +
                omega /= embed_dim / 2.0
         | 
| 125 | 
            +
                omega = 1.0 / 10000**omega  # (D/2,)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                pos = pos.reshape(-1)  # (M,)
         | 
| 128 | 
            +
                out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                emb_sin = np.sin(out)  # (M, D/2)
         | 
| 131 | 
            +
                emb_cos = np.cos(out)  # (M, D/2)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
         | 
| 134 | 
            +
                return emb
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
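            # --- Illustrative usage sketch (not part of pos_embed.py) -------------------
            # Quick shape check for get_3d_sincos_pos_embed: embed_dim must be divisible
            # by 4; one quarter of the channels encodes time, three quarters encode space.
            # The import path below is an assumption; adjust to wherever this file lives.
            from pos_embed import get_3d_sincos_pos_embed

            embed_dim, grid_size, t_size = 64, 14, 4
            pos = get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=True)
            # one zero row for the cls token + t_size * grid_size * grid_size patch rows
            assert pos.shape == (1 + t_size * grid_size * grid_size, embed_dim)  # (785, 64)
            print(pos.shape)
            # -----------------------------------------------------------------------------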
| 137 | 
            +
            def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
         | 
| 138 | 
            +
                if pos_name in checkpoint_model:
         | 
| 139 | 
            +
                    pos_embed_checkpoint = checkpoint_model[pos_name]
         | 
| 140 | 
            +
                    embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
         | 
| 141 | 
            +
                    num_patches = model.patch_embed.num_patches # 
         | 
| 142 | 
            +
                    num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    # we use 4 frames for pretraining
         | 
| 145 | 
            +
                    new_t_size = model.T
         | 
| 146 | 
            +
                    # height (== width) for the checkpoint position embedding
         | 
| 147 | 
            +
                    orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
         | 
| 148 | 
            +
                    # height (== width) for the new position embedding
         | 
| 149 | 
            +
                    new_size = int((num_patches // (new_t_size))** 0.5)
         | 
| 150 | 
            +
                    
         | 
| 151 | 
            +
                    # class_token and dist_token are kept unchanged
         | 
| 152 | 
            +
                    if orig_t_size != new_t_size:
         | 
| 153 | 
            +
                        logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
         | 
| 154 | 
            +
                        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
         | 
| 155 | 
            +
                        # only the position tokens are interpolated
         | 
| 156 | 
            +
                        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         | 
| 157 | 
            +
                        # B, L, C -> B, T, HW, C -> BHW, C, T  (B = 1)
         | 
| 158 | 
            +
                        pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
         | 
| 159 | 
            +
                        pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
         | 
| 160 | 
            +
                        pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
         | 
| 161 | 
            +
                        pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
         | 
| 162 | 
            +
                        pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
         | 
| 163 | 
            +
                        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
         | 
| 164 | 
            +
                        checkpoint_model[pos_name] = new_pos_embed
         | 
| 165 | 
            +
                        pos_embed_checkpoint = new_pos_embed
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    # class_token and dist_token are kept unchanged
         | 
| 168 | 
            +
                    if orig_size != new_size:
         | 
| 169 | 
            +
                        logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
         | 
| 170 | 
            +
                        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
         | 
| 171 | 
            +
                        # only the position tokens are interpolated
         | 
| 172 | 
            +
                        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         | 
| 173 | 
            +
                        # B, L, C -> BT, H, W, C -> BT, C, H, W
         | 
| 174 | 
            +
                        pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
         | 
| 175 | 
            +
                        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
         | 
| 176 | 
            +
                        pos_tokens = torch.nn.functional.interpolate(
         | 
| 177 | 
            +
                            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
         | 
| 178 | 
            +
                        # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
         | 
| 179 | 
            +
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
            pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model[pos_name] = new_pos_embed


def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size=8):
    # interpolate position embedding
    for pos_name in ['pos_embed', 'clip_pos_embed']:
        if pos_name in checkpoint_model:
            pos_embed_checkpoint = checkpoint_model[pos_name]
            embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1

            # we use 8 frames for pretraining
            # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
            new_t_size = model.num_frames // model.tubelet_size
            # height (== width) for the checkpoint position embedding
            orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
            # height (== width) for the new position embedding
            new_size = int((num_patches // new_t_size) ** 0.5)

            # class_token and dist_token are kept unchanged
            if orig_t_size != new_t_size:
                logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                # B, L, C -> B, T, HW, C -> BHW, C, T  (B = 1)
                pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
                pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
                pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
                pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model[pos_name] = new_pos_embed
                pos_embed_checkpoint = new_pos_embed

            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                # B, L, C -> BT, H, W, C -> BT, C, H, W
                pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
                # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
                pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model[pos_name] = new_pos_embed

    if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
        raise NotImplementedError


def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size=8):
    pos_names = []
    for k in checkpoint_model.keys():
        if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k:
            pos_names.append(k)

    logger.info(f"pos names list for interpolating: {pos_names}")

    assert len(pos_names) > 0, checkpoint_model.keys()

    if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
        raise NotImplementedError

    # interpolate position embedding
    for pos_name in pos_names:
        pos_embed_checkpoint = checkpoint_model[pos_name]
        embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1

        # we use 8 frames for pretraining
        # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
        new_t_size = model.num_frames // model.tubelet_size
        # height (== width) for the checkpoint position embedding
        orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int((num_patches // new_t_size) ** 0.5)

        # class_token and dist_token are kept unchanged
        if orig_t_size != new_t_size:
            logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # B, L, C -> B, T, HW, C -> BHW, C, T  (B = 1)
            pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
            pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
            pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model[pos_name] = new_pos_embed
            pos_embed_checkpoint = new_pos_embed

        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # B, L, C -> BT, H, W, C -> BT, C, H, W
            pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
            pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model[pos_name] = new_pos_embed
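Both helpers rewrite the position-embedding tensors inside a checkpoint state dict in place, so they are meant to run between torch.load and load_state_dict. A minimal usage sketch, assuming a model built from this repo and a hypothetical checkpoint path (neither is prescribed by the file itself):

import torch

# Hypothetical example: resize the position embeddings of a state dict that was
# pretrained with 8 frames so it fits the currently configured model.
state_dict = torch.load("internvideo2_pretrained.pth", map_location="cpu")  # placeholder path
interpolate_pos_embed_internvideo2_new(state_dict, model, orig_t_size=8)    # edits state_dict in place
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys, msg.unexpected_keys)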
    	
        test.ipynb
    ADDED
    
    | @@ -0,0 +1,424 @@ | |
The notebook walks through a small text-to-video retrieval demo with the uploaded model. Cell 1 loads the config and the model with trust_remote_code and moves it to the configured device:

from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)

Loading emits a tqdm warning on stderr:

/root/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

and prints the resolved configuration:

InternVideo2Config {
  "_attn_implementation_autoset": true,
  "architectures": [
    "InternVideo2_CLIP_small"
  ],
  "auto_map": {
    "AutoConfig": "config.InternVideo2Config",
    "AutoModel": "modeling_internvideo2encoder.InternVideo2_CLIP_small"
  },
  "auto_resume": false,
  "batch_size": 64,
  "batch_size_test": 4,
  "best_key": [
    "msrvtt_1k_test_match",
    "t2v_r1"
  ],
  "compile_model": false,
  "criterion": {
    "clip_loss_ratio": [
      1.0,
      1.0
    ],
    "distill_final_features": true,
    "loss_weight": {
      "mlm": 1.0,
      "mvm": 0.0,
      "uta": 0.0,
      "vtc": 1.0,
      "vtm": 1.0
    },
    "mlm_masking_prob": 0.5,
    "vtm_hard_neg": true
  },
  "debug": false,
  "deep_fusion": false,
  "deepspeed": {
    "enable": true,
    "stage": 1
  },
  "delete_ds_optim_states": true,
  "device": "cuda",
  "dist_url": "env://",
  "evaluate": false,
  "evaluation": {
    "eval_frame_ensemble": "concat",
    "eval_offload": true,
    "eval_x_only": false,
    "k_test": 128
  },
  "gradient_checkpointing": true,
  "inputs": {
    "batch_size": {
      "image": 64,
      "video": 64
    },
    "batch_size_test": {
      "image": 4,
      "video": 4
    },
    "image_res": 224,
    "max_txt_l": {
      "image": 32,
      "video": 32
    },
    "video_input": {
      "num_frames": 8,
      "num_frames_test": 8,
      "random_aug": false,
      "sample_type": "middle",
      "sample_type_test": "middle"
    }
  },
  "jump_evaluate": false,
  "log_freq": 100,
  "max_txt_l": 32,
  "mode": "pt",
  "model": {
    "embed_dim": 1024,
    "find_unused_parameters": false,
    "freeze_text": true,
    "freeze_vision": true,
    "load_vision_ckpt_from_internvideo2_stage2": false,
    "model_cls": "InternVideo2_CLIP_small",
    "multimodal": {
      "enable": true
    },
    "open_text_projection": false,
    "open_vision_clip_projector": true,
    "temp": 0.01,
    "temp_min": 0.01,
    "text_encoder": {
      "embed_dim": 512,
      "image_cfg": {
        "image_size": 224,
        "model_name": "vit_b16"
      },
      "text_cfg": {
        "causal_masking": true,
        "context_length": 77,
        "dim": 512,
        "ffn_multiplier_per_layer": 4.0,
        "model_name": "base",
        "n_heads_per_layer": 8,
        "n_transformer_layers": 12,
        "norm_layer": "layer_norm_fp32",
        "vocab_size": 49408
      }
    },
    "vision_encoder": {
      "align_dim": 512,
      "attn_pool_num_heads": 16,
      "checkpoint_num": 0,
      "clip_embed_dim": 768,
      "depth": 24,
      "drop_cls_token": false,
      "drop_path_rate": 0.0,
      "embed_dim": 1024,
      "fused_mlp_heuristic": 1,
      "head_drop_path_rate": 0.0,
      "img_size": 224,
      "in_chans": 3,
      "init_values": 0.1,
      "layerscale_no_force_fp32": true,
      "mlp_ratio": 4,
      "name": "internvideo2_1B",
      "num_frames": 8,
      "num_heads": 16,
      "patch_size": 14,
      "qk_normalization": true,
      "qkv_bias": false,
      "sep_pos_embed": false,
      "tubelet_size": 1,
      "use_checkpoint": false,
      "use_flash_attn": false,
      "use_fused_mlp": false,
      "use_fused_rmsnorm": false
    }
  },
  "model_type": "internvideo2",
  "num_frames": 8,
  "num_frames_test": 8,
  "num_workers": 6,
  "optimizer": {
    "different_lr": {
      "enable": false,
      "lr": 0.001,
      "module_names": []
    },
    "lr": 5e-05,
    "max_grad_norm": 3.0,
    "opt": "adamW",
    "opt_betas": [
      0.9,
      0.98
    ],
    "weight_decay": 0.05
  },
  "output_dir": null,
  "pretrained_path": "",
  "resume": false,
  "save_ckpt_iter": null,
  "save_latest": true,
  "scheduler": {
    "epochs": 10,
    "min_lr_multi": 0.01,
    "sched": "cosine",
    "warmup_epochs": 1
  },
  "seed": 42,
  "test_file": {
    "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]",
    "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]"
  },
  "test_types": [
    "msrvtt_1k_test",
    "didemo_ret_test"
  ],
  "text_enc": "bert_large",
  "tokenizer": null,
  "torch_dtype": "float32",
  "train_file": "available_corpus[\"pretrain_example_data_1B\"]",
  "transformers_version": "4.51.3",
  "use_bf16": true,
  "use_flash_sdp": false,
  "use_half_precision": false,
  "use_mem_efficient_sdp": false,
  "wandb": {
    "enable": false,
    "entity": "opengvlab",
    "project": "InternVideo2-Stage2"
  }
}
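The printed vision_encoder settings determine the position-embedding geometry that pos_embed.py has to match. A quick back-of-the-envelope check derived from img_size, patch_size, num_frames, and tubelet_size above (illustrative arithmetic, not notebook output):

# Token geometry implied by the config above (illustrative, not printed by the notebook).
img_size, patch_size = 224, 14
num_frames, tubelet_size = 8, 1

grid = img_size // patch_size           # 16 patches per side
tokens_per_frame = grid * grid          # 256 spatial tokens per frame
t = num_frames // tubelet_size          # 8 temporal positions
patch_tokens = t * tokens_per_frame     # 2048 patch tokens covered by pos_embed
print(grid, tokens_per_frame, t, patch_tokens)  # 16 256 8 2048

Any extra class token on top of these 2048 positions is what the interpolation helpers read off as num_extra_tokens and leave unchanged.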
Cell 2 defines the frame-sampling and video-decoding helpers used by the demo (decord with the torch bridge):

import os
import random
import io
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import numpy as np
import math
import torch.nn.functional as F
decord.bridge.set_bridge("torch")


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    # Ensure start_frame is less than end_frame
    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    # Calculate the length of the clip in frames
    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, clip_length)
        # split the clip into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except:
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError
    return frame_indices


def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
    ):
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)

    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration
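The config uses sample_type "middle", so the clip is split into num_frames evenly sized intervals and the centre frame of each interval is taken. An illustrative call with made-up numbers (assumes the cell above has been run; not notebook output):

# For a hypothetical 100-frame video and 8 requested frames:
idx = get_frame_indices(num_frames=8, vlen=100, sample='middle')
print(idx)  # [5, 18, 30, 43, 55, 68, 80, 93] - one index from the middle of each interval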
Cell 3 defines the text-feature and similarity helpers:

def get_text_feature(model, texts):
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features

def get_similarity(video_feature, text_feature):
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix

Cell 4 ranks the top-k videos for each text query and collects the results:

def get_top_videos(model, text_features, video_features, video_paths, texts):
    # text_features = get_text_feature(texts)

    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # print(text_features.shape, video_features.shape)
    sim_matrix = text_features @ video_features.T
    # print(sim_matrix.shape)

    top_k = 5
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            print("top", j+1, ":", video_paths[sim_matrix_top_k[i][j]], "~prob:", sim_matrix[i][sim_matrix_top_k[i][j]].item())
            retrieval_infos[texts[i]].append({"video": video_paths[sim_matrix_top_k[i][j]], "prob": sim_matrix[i][sim_matrix_top_k[i][j]].item(), "rank": j+1})
    return retrieval_infos
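One caveat in get_top_videos: softmax_sim_matrix is computed but never used, so the value printed as "~prob" is the raw cosine similarity rather than a probability. If per-query probabilities are wanted, a temperature-scaled softmax over each row is one option; a minimal sketch, with an illustrative temperature that is not taken from this repo:

import torch
import torch.nn.functional as F

# Stand-in similarity matrix (3 text queries x 5 videos); in the notebook this is
# text_features @ video_features.T with both sides L2-normalized.
sim_matrix = torch.randn(3, 5)
temperature = 0.01  # illustrative value, not taken from this repo
probs = F.softmax(sim_matrix / temperature, dim=1)  # each row now sums to 1
print(probs.sum(dim=1))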
Cell 5 runs the demo end to end: it decodes each clip, extracts video and text features, and retrieves the best-matching videos per query (the demo video paths refer to local clips that are not part of this commit):

if __name__ == "__main__":
    video_features = []
    demo_videos = ["video-scene-00030.mp4", "video-scene-00031.mp4", "xinhuashe_test_video/video-scene-00032.mp4", "xinhuashe_test_video/video-scene-00033.mp4", "video-scene-00034.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        # get the video features
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
            video_features.append(video_feature)

    # get the text features
    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)

Notebook metadata: kernelspec "base" (python3), Python 3.10.15, nbformat 4 (minor 2).

