	Upload 2 files
MeshAnything/models/meshanything_v2.py
ADDED

@@ -0,0 +1,162 @@
+import torch
+import torch.nn.functional as nnf
+from torch import nn
+import random
+from transformers import AutoModelForCausalLM
+from MeshAnything.miche.encode import load_model
+from MeshAnything.models.shape_opt import ShapeOPTConfig
+from einops import repeat, reduce, rearrange, pack, unpack
+
+class MeshAnythingV2(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.point_encoder = load_model(ckpt_path=None)
+        self.n_discrete_size = 128
+        self.max_seq_ratio = 0.70
+        self.face_per_token = 9
+        self.cond_length = 257
+        self.cond_dim = 768
+        self.pad_id = -1
+        self.n_max_triangles = 1600
+        self.max_length = int(self.n_max_triangles * self.face_per_token * self.max_seq_ratio + 3 + self.cond_length) # add 1
+
+        self.coor_continuous_range = (-0.5, 0.5)
+
+        self.config = ShapeOPTConfig.from_pretrained(
+            "facebook/opt-350m",
+            n_positions=self.max_length,
+            max_position_embeddings=self.max_length,
+            vocab_size=self.n_discrete_size + 4,
+            _attn_implementation="flash_attention_2"
+        )
+
+        self.bos_token_id = 0
+        self.eos_token_id = 1
+        self.pad_token_id = 2
+
+        self.config.bos_token_id = self.bos_token_id
+        self.config.eos_token_id = self.eos_token_id
+        self.config.pad_token_id = self.pad_token_id
+        self.config._attn_implementation="flash_attention_2"
+        self.config.n_discrete_size = self.n_discrete_size
+        self.config.face_per_token = self.face_per_token
+        self.config.cond_length = self.cond_length
+
+        if self.config.word_embed_proj_dim != self.config.hidden_size:
+            self.config.word_embed_proj_dim = self.config.hidden_size
+        self.transformer = AutoModelForCausalLM.from_config(
+            config=self.config, use_flash_attention_2 = True
+        )
+        self.transformer.to_bettertransformer()
+
+        self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
+        self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim)
+
+        self.eval()
+
+    def adjacent_detokenize(self, input_ids):
+        input_ids = input_ids.reshape(input_ids.shape[0], -1) # B x L
+        batch_size = input_ids.shape[0]
+        continuous_coors = torch.zeros((batch_size, self.n_max_triangles * 3 * 10, 3), device=input_ids.device)
+        continuous_coors[...] = float('nan')
+
+        for i in range(batch_size):
+            cur_ids = input_ids[i]
+            coor_loop_check = 0
+            vertice_count = 0
+            continuous_coors[i, :3, :] = torch.tensor([[-0.1, 0.0, 0.1], [-0.1, 0.1, 0.2], [-0.3, 0.3, 0.2]],
+                                                      device=input_ids.device)
+            for id in cur_ids:
+                if id == self.pad_id:
+                    break
+                elif id == self.n_discrete_size:
+                    if coor_loop_check < 9:
+                        break
+                    if coor_loop_check % 3 !=0:
+                        break
+                    coor_loop_check = 0
+                else:
+
+                    if coor_loop_check % 3 == 0 and coor_loop_check >= 9:
+                        continuous_coors[i, vertice_count] = continuous_coors[i, vertice_count-2]
+                        continuous_coors[i, vertice_count+1] = continuous_coors[i, vertice_count-1]
+                        vertice_count += 2
+                    continuous_coors[i, vertice_count, coor_loop_check % 3] = undiscretize(id, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
+                    if coor_loop_check % 3 == 2:
+                        vertice_count += 1
+                    coor_loop_check += 1
+
+        continuous_coors = rearrange(continuous_coors, 'b (nf nv) c -> b nf nv c', nv=3, c=3)
+
+        return continuous_coors # b, nf, 3, 3
+
+
+    def forward(self, data_dict: dict, is_eval: bool = False) -> dict:
+        if not is_eval:
+            return self.train_one_step(data_dict)
+        else:
+            return self.generate(data_dict)
+
+    def process_point_feature(self, point_feature):
+        encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim,
+                                     device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
+        encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
+        shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
+        encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1))
+
+        return encode_feature
+
+    @torch.no_grad()
+    def forward(self, pc_normal, sampling=False) -> dict:
+        batch_size = pc_normal.shape[0]
+        point_feature = self.point_encoder.encode_latents(pc_normal)
+        processed_point_feature = self.process_point_feature(point_feature)
+        generate_length = self.max_length - self.cond_length
+        net_device = next(self.parameters()).device
+        outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id
+        # batch x ntokens
+        if not sampling:
+            results = self.transformer.generate(
+                inputs_embeds=processed_point_feature,
+                max_new_tokens=generate_length,  # all faces plus two
+                num_beams=1,
+                bos_token_id=self.bos_token_id,
+                eos_token_id=self.eos_token_id,
+                pad_token_id=self.pad_token_id,
+            )
+        else:
+            results = self.transformer.generate(
+                inputs_embeds = processed_point_feature,
+                max_new_tokens = generate_length, # all faces plus two
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                bos_token_id = self.bos_token_id,
+                eos_token_id = self.eos_token_id,
+                pad_token_id = self.pad_token_id,
+            )
+        assert results.shape[1] <= generate_length # B x ID  bos is not included since it's predicted
+        outputs[:, :results.shape[1]] = results
+        # batch x ntokens ====> batch x ntokens x D
+        outputs = outputs[:, 1: -1]
+
+        outputs[outputs == self.bos_token_id] = self.pad_id
+        outputs[outputs == self.eos_token_id] = self.pad_id
+        outputs[outputs == self.pad_token_id] = self.pad_id
+
+        outputs[outputs != self.pad_id] -= 3
+        gen_mesh = self.adjacent_detokenize(outputs)
+
+        return gen_mesh
+
+def undiscretize(
+    t,
+    low,#-0.5
+    high,# 0.5
+    num_discrete
+):
+    t = t.float() #[0, num_discrete-1]
+
+    t /= num_discrete  # 0<=t<1
+    t = t * (high - low) + low # -0.5 <= t < 0.5
+    return t
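Note for readers of the diff: the generation entry point of this new module is the second, @torch.no_grad() forward(pc_normal, sampling=False) definition (it shadows the earlier dict-based forward). It encodes a point cloud with normals through the Michelangelo encoder, autoregressively decodes coordinate tokens with the ShapeOPT transformer, and detokenizes them into per-face vertex coordinates. A minimal driver might look like the sketch below; the checkpoint handling and the random input are assumptions for illustration, not part of this commit.

import torch
import torch.nn.functional as F
from MeshAnything.models.meshanything_v2 import MeshAnythingV2

model = MeshAnythingV2().cuda().eval()   # weights would normally come from a released checkpoint (assumed)

# A batch of 8192 surface samples: xyz in [-0.5, 0.5] plus unit normals (random placeholder data).
pc = torch.rand(1, 8192, 6, device="cuda")
pc[..., :3] -= 0.5
pc[..., 3:] = F.normalize(pc[..., 3:], dim=-1)

with torch.no_grad():
    faces = model(pc, sampling=False)    # greedy decoding; sampling=True switches to top-k/top-p
print(faces.shape)                       # (1, n_max_triangles * 10, 3, 3), NaN-padded for unused faces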
    	
MeshAnything/models/shape_opt.py
CHANGED

@@ -8,9 +8,8 @@ from transformers.modeling_outputs import (
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from transformers.utils import replace_return_docstrings
+from transformers.utils import replace_return_docstrings
 from transformers.modeling_outputs import BaseModelOutputWithPast
-# from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
 class ShapeOPTConfig(OPTConfig):
     model_type = "shape_opt"
@@ -26,23 +25,6 @@ class ShapeOPT(OPTForCausalLM):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def tie_weights(self):
-        """
-        Tie the weights between the input embeddings and the output embeddings.
-
-        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
-        weights instead.
-        """
-        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
-            if hasattr(self, self.base_model_prefix):
-                self = getattr(self, self.base_model_prefix)
-            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
-
-        for module in self.modules():
-            if hasattr(module, "_tie_weights"):
-                module._tie_weights()
-
-
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="OPTConfig")
     def forward(
         self,
@@ -140,7 +122,7 @@ class ShapeOPT(OPTForCausalLM):
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model.decoder(
-            input_ids=input_ids,
+            input_ids = input_ids,
             face_ids = face_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -195,28 +177,18 @@ class ShapeOPTDecoder(OPTDecoder):
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
         self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) # not used
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
         self.hidden_size = config.hidden_size
        self.word_embed_proj_dim = config.word_embed_proj_dim
-        self.
-        self.input_layer = nn.Linear(config.quantize_codebook_dim, config.word_embed_proj_dim)
+        self.n_discrete_size = config.n_discrete_size
 
         self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
-        self.token_embed_positions =
+        self.token_embed_positions = OPTLoopEmbedding(10, config.word_embed_proj_dim, self.n_discrete_size) #padding_idx=self.padding_idx)
+
         self.face_per_token = config.face_per_token
         self.cond_length = config.cond_length
         self.cond_embed = nn.Embedding(2, config.word_embed_proj_dim)
 
-        if config.word_embed_proj_dim != config.hidden_size:
-            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
-        else:
-            self.project_out = None
-
-        if config.word_embed_proj_dim != config.hidden_size:
-            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
-        else:
-            self.project_in = None
         # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
         # with checkpoints that have been fine-tuned before transformers v4.20.1
         # see https://github.com/facebookresearch/metaseq/pull/164
@@ -234,17 +206,6 @@ class ShapeOPTDecoder(OPTDecoder):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def embed_with_vae(self, input_ids):
-        inputs_embeds = repeat(torch.zeros(input_ids.shape, device=input_ids.device), 'b n -> b n d',
-                               d=self.word_embed_proj_dim).clone().detach()
-        idx_in_extra = torch.isin(input_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
-        inputs_embeds[idx_in_extra] += self.extra_embeds(input_ids[idx_in_extra])
-        self.quantize_codebooks = self.quantize_codebooks.to(input_ids.device)
-        inputs_embeds[~idx_in_extra] += self.input_layer(self.quantize_codebooks[0][input_ids[~idx_in_extra] - 3])
-
-        return inputs_embeds
-
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -315,11 +276,13 @@ class ShapeOPTDecoder(OPTDecoder):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         # Transformer Decoder
-        if input_ids is not None:
+        if input_ids is not None and inputs_embeds is not None: # when  train and first generate
+            assert False
+        elif input_ids is not None:
+            assert not self.training
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
-            inputs_embeds = self.
-
+            inputs_embeds = self.embed_tokens(input_ids)
             face_embeds = self.token_embed_positions(attention_mask[:, self.cond_length:], face_ids, input_ids,
                                                      self.face_per_token)
             inputs_embeds += face_embeds
@@ -329,7 +292,8 @@ class ShapeOPTDecoder(OPTDecoder):
 
         elif inputs_embeds is not None:
             # assert self.cond and not self.training
-
+            assert not self.training
+            self.token_embed_positions.init_state(inputs_embeds)
             total_length = inputs_embeds.shape[1] # B x length x embeding
             cond_embed_query = torch.zeros((inputs_embeds.shape[0], total_length), device=inputs_embeds.device,
                                             dtype=inputs_embeds.dtype).long()
@@ -357,9 +321,6 @@ class ShapeOPTDecoder(OPTDecoder):
 
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
-        if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
-
         hidden_states = inputs_embeds + pos_embeds
 
         # decoder layers
@@ -419,9 +380,6 @@ class ShapeOPTDecoder(OPTDecoder):
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
-
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -436,6 +394,56 @@ class ShapeOPTDecoder(OPTDecoder):
             attentions=all_self_attns,
         )
 
+class OPTLoopEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, n_discrete_size: int):
+        super().__init__(num_embeddings, embedding_dim)
+        self.state = None
+        self.loop_state = None
+        self.n_discrete_size = n_discrete_size + 3 # for padding
+
+    def forward(self, attention_mask=None, face_ids = None, input_ids = None, face_per_token = None):
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        if face_ids is not None:
+            return super().forward(face_ids)
+
+        assert input_ids.shape[1] == 1, "Only one token is allowed for loop embedding"
+        assert self.state is not None, "State is not initialized"
+        # zero as beginning
+        batch_size = input_ids.shape[0]
+        face_ids = input_ids.clone().detach()
+
+        for cur_batch_index in range(batch_size):
+            cur_ids = input_ids[cur_batch_index]
+
+            idx_in_extra = torch.isin(cur_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
+            if idx_in_extra:
+                self.state[cur_batch_index] = 9  # init
+                self.loop_state[cur_batch_index] = 0
+            else:
+                if cur_ids == self.n_discrete_size:
+                    face_ids[cur_batch_index] = 3
+                    self.state[cur_batch_index] = 9 # init
+                    self.loop_state[cur_batch_index] = 0
+                else:
+                    if self.state[cur_batch_index] == 0:
+                        face_ids[cur_batch_index] = 7 + self.loop_state[cur_batch_index] % 3
+                    else:
+                        self.state[cur_batch_index] -= 1
+                        face_ids[cur_batch_index] = 4 + self.loop_state[cur_batch_index] % 3
+                    self.loop_state[cur_batch_index] += 1
+
+        return super().forward(face_ids)
+
+    def init_state(self, template_tensor):
+        batch_size = template_tensor.shape[0]
+        self.state = torch.zeros((batch_size, 1), dtype=torch.long, device=template_tensor.device)
+        self.state[...] = 9
+        self.loop_state = torch.zeros((batch_size, 1), dtype=torch.long, device=template_tensor.device)
+
 class OPTFacePositionalEmbedding(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
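The main addition to shape_opt.py is OPTLoopEmbedding, used together with the re-enabled nn.Embedding token table in place of the removed VAE-codebook path. During one-token-at-a-time decoding it keeps a per-sample countdown (state) and a rolling counter (loop_state) and maps each generated token to one of 10 learned slots: 0-2 for the special tokens, 3 for the end-of-loop marker, 4-6 for the first nine coordinates of a face loop, and 7-9 for coordinates that extend an existing loop. The standalone sketch below restates that state machine in plain Python purely as a reading aid; the function name and example token values are illustrative and not part of the commit (131 here is n_discrete_size + 3, the un-shifted marker that MeshAnythingV2 later maps back to 128).

def loop_face_ids(token_ids, n_discrete_size=128):
    # Mirrors the rules in OPTLoopEmbedding.forward for single-token decoding steps,
    # applied over a whole sequence at once and without the embedding lookup.
    new_loop_id = n_discrete_size + 3        # the value stored in self.n_discrete_size
    state, loop_state, slots = 9, 0, []
    for tok in token_ids:
        if tok in (0, 1, 2):                 # bos / eos / pad keep their own slot
            slots.append(tok)
            state, loop_state = 9, 0
        elif tok == new_loop_id:             # end-of-loop marker
            slots.append(3)
            state, loop_state = 9, 0
        else:                                # coordinate token
            if state == 0:
                slots.append(7 + loop_state % 3)   # extends an existing loop
            else:
                state -= 1
                slots.append(4 + loop_state % 3)   # one of the first nine coordinates
            loop_state += 1
    return slots

# The first triangle of a loop uses slots 4-6, later loop vertices use 7-9:
print(loop_face_ids([0] + [10] * 12 + [131]))
# [0, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 3]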