import torch
import torch.nn.functional as nnf
from torch import nn
import random
from transformers import AutoModelForCausalLM
from MeshAnything.miche.encode import load_model
from MeshAnything.models.shape_opt import ShapeOPTConfig
from einops import repeat, reduce, rearrange, pack, unpack


class MeshAnythingV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.point_encoder = load_model(ckpt_path=None)

        self.n_discrete_size = 128
        self.max_seq_ratio = 0.70
        self.face_per_token = 9
        self.cond_length = 257
        self.cond_dim = 768
        self.pad_id = -1
        self.n_max_triangles = 1600
        self.max_length = int(self.n_max_triangles * self.face_per_token * self.max_seq_ratio + 3 + self.cond_length)  # add 1
        self.coor_continuous_range = (-0.5, 0.5)

        self.config = ShapeOPTConfig.from_pretrained(
            "facebook/opt-350m",
            n_positions=self.max_length,
            max_position_embeddings=self.max_length,
            vocab_size=self.n_discrete_size + 4,
            _attn_implementation="flash_attention_2"
        )

        self.bos_token_id = 0
        self.eos_token_id = 1
        self.pad_token_id = 2
        self.config.bos_token_id = self.bos_token_id
        self.config.eos_token_id = self.eos_token_id
        self.config.pad_token_id = self.pad_token_id
        self.config._attn_implementation = "flash_attention_2"
        self.config.n_discrete_size = self.n_discrete_size
        self.config.face_per_token = self.face_per_token
        self.config.cond_length = self.cond_length

        if self.config.word_embed_proj_dim != self.config.hidden_size:
            self.config.word_embed_proj_dim = self.config.hidden_size
        self.transformer = AutoModelForCausalLM.from_config(
            config=self.config, use_flash_attention_2=True
        )
        self.transformer.to_bettertransformer()

        self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
        self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim)

        self.eval()
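
    # Vocabulary layout (descriptive note): with vocab_size = n_discrete_size + 4,
    # ids 0/1/2 are bos/eos/pad; the remaining ids map, after the "-= 3" shift
    # applied in forward, to the 128 discrete coordinate bins plus one
    # face-strip separator token, which adjacent_detokenize below consumes.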

    def adjacent_detokenize(self, input_ids):
        input_ids = input_ids.reshape(input_ids.shape[0], -1)  # B x L
        batch_size = input_ids.shape[0]
        continuous_coors = torch.zeros((batch_size, self.n_max_triangles * 3 * 10, 3), device=input_ids.device)
        continuous_coors[...] = float('nan')

        for i in range(batch_size):
            cur_ids = input_ids[i]
            coor_loop_check = 0
            vertice_count = 0
            continuous_coors[i, :3, :] = torch.tensor([[-0.1, 0.0, 0.1], [-0.1, 0.1, 0.2], [-0.3, 0.3, 0.2]],
                                                      device=input_ids.device)
            for id in cur_ids:
                if id == self.pad_id:
                    break
                elif id == self.n_discrete_size:
                    if coor_loop_check < 9:
                        break
                    if coor_loop_check % 3 != 0:
                        break
                    coor_loop_check = 0
                else:
                    if coor_loop_check % 3 == 0 and coor_loop_check >= 9:
                        continuous_coors[i, vertice_count] = continuous_coors[i, vertice_count - 2]
                        continuous_coors[i, vertice_count + 1] = continuous_coors[i, vertice_count - 1]
                        vertice_count += 2
                    continuous_coors[i, vertice_count, coor_loop_check % 3] = undiscretize(
                        id, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
                    if coor_loop_check % 3 == 2:
                        vertice_count += 1
                    coor_loop_check += 1

        continuous_coors = rearrange(continuous_coors, 'b (nf nv) c -> b nf nv c', nv=3, c=3)

        return continuous_coors  # b, nf, 3, 3
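
    # Adjacent tokenization, as decoded above: ids in [0, n_discrete_size) are
    # discretized coordinates read in x, y, z order, and id == n_discrete_size
    # closes the current face strip. Within a strip the first face costs nine
    # ids (three vertices); every following face reuses the previous two
    # vertices and costs only three ids (one new vertex).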

    # NOTE: this training/eval dispatcher is shadowed by the inference-only
    # forward defined further below; train_one_step and generate are not
    # defined in this file.
    def forward(self, data_dict: dict, is_eval: bool = False) -> dict:
        if not is_eval:
            return self.train_one_step(data_dict)
        else:
            return self.generate(data_dict)

    def process_point_feature(self, point_feature):
        encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim,
                                     device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
        encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
        shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
        encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1))
        return encode_feature
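
    # Conditioning layout: slot 0 of encode_feature is the projected global
    # point feature (cond_head_proj); the remaining cond_length - 1 slots
    # concatenate the per-token point features with the encoder's shape latents
    # and project them through cond_proj, giving the (B, 257, hidden) prefix
    # that is fed to the decoder as inputs_embeds.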

    def forward(self, pc_normal, sampling=False) -> torch.Tensor:
        batch_size = pc_normal.shape[0]
        point_feature = self.point_encoder.encode_latents(pc_normal)
        processed_point_feature = self.process_point_feature(point_feature)
        generate_length = self.max_length - self.cond_length
        net_device = next(self.parameters()).device
        outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id
        # batch x ntokens
        if not sampling:
            results = self.transformer.generate(
                inputs_embeds=processed_point_feature,
                max_new_tokens=generate_length,  # all faces plus two
                num_beams=1,
                bos_token_id=self.bos_token_id,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        else:
            results = self.transformer.generate(
                inputs_embeds=processed_point_feature,
                max_new_tokens=generate_length,  # all faces plus two
                do_sample=True,
                top_k=50,
                top_p=0.95,
                bos_token_id=self.bos_token_id,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        assert results.shape[1] <= generate_length  # B x ID, bos is not included since it's predicted
        outputs[:, :results.shape[1]] = results
        # batch x ntokens ====> batch x ntokens x D
        outputs = outputs[:, 1:-1]

        # Map every special token to pad_id, then shift the remaining ids down
        # by 3 so they land in [0, n_discrete_size] for detokenization.
        outputs[outputs == self.bos_token_id] = self.pad_id
        outputs[outputs == self.eos_token_id] = self.pad_id
        outputs[outputs == self.pad_token_id] = self.pad_id
        outputs[outputs != self.pad_id] -= 3

        gen_mesh = self.adjacent_detokenize(outputs)
        return gen_mesh


def undiscretize(
    t,
    low,           # -0.5
    high,          # 0.5
    num_discrete
):
    t = t.float()  # [0, num_discrete - 1]
    t /= num_discrete  # 0 <= t < 1
    t = t * (high - low) + low  # -0.5 <= t < 0.5
    return t
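

# Minimal usage sketch (illustrative only). Assumptions: pretrained weights are
# obtained separately and loaded via load_state_dict (the checkpoint path below
# is a placeholder), a CUDA device is available, and the input point cloud is a
# (B, N, 6) tensor of xyz coordinates concatenated with unit normals. Depending
# on the environment, casting the model and inputs to fp16/bf16 may be required
# for the FlashAttention backend configured above.
if __name__ == "__main__":
    device = torch.device("cuda")
    model = MeshAnythingV2().to(device).eval()
    # model.load_state_dict(torch.load("checkpoints/meshanything_v2.pth", map_location=device))  # placeholder path

    pc_normal = torch.rand(1, 8192, 6, device=device)  # dummy points + normals, shapes only
    with torch.no_grad():
        faces = model(pc_normal, sampling=False)  # (1, n_faces, 3, 3), NaN-padded rows
    print(faces.shape)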