|  |  | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | import numpy as np | 
					
						
						|  | import pickle | 
					
						
						|  | from enum import Enum | 
					
						
						|  | from typing import Optional | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  | from detectron2.config import CfgNode | 
					
						
						|  | from detectron2.utils.file_io import PathManager | 
					
						
						|  |  | 
					
						
						|  | from .vertex_direct_embedder import VertexDirectEmbedder | 
					
						
						|  | from .vertex_feature_embedder import VertexFeatureEmbedder | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class EmbedderType(Enum): | 
					
						
						|  | """ | 
					
						
						|  | Embedder type which defines how vertices are mapped into the embedding space: | 
					
						
						|  | - "vertex_direct": direct vertex embedding | 
					
						
						|  | - "vertex_feature": embedding vertex features | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | VERTEX_DIRECT = "vertex_direct" | 
					
						
						|  | VERTEX_FEATURE = "vertex_feature" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module: | 
					
						
						|  | """ | 
					
						
						|  | Create an embedder based on the provided configuration | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | embedder_spec (CfgNode): embedder configuration | 
					
						
						|  | embedder_dim (int): embedding space dimensionality | 
					
						
						|  | Return: | 
					
						
						|  | An embedder instance for the specified configuration | 
					
						
						|  | Raises ValueError, in case of unexpected  embedder type | 
					
						
						|  | """ | 
					
						
						|  | embedder_type = EmbedderType(embedder_spec.TYPE) | 
					
						
						|  | if embedder_type == EmbedderType.VERTEX_DIRECT: | 
					
						
						|  | embedder = VertexDirectEmbedder( | 
					
						
						|  | num_vertices=embedder_spec.NUM_VERTICES, | 
					
						
						|  | embed_dim=embedder_dim, | 
					
						
						|  | ) | 
					
						
						|  | if embedder_spec.INIT_FILE != "": | 
					
						
						|  | embedder.load(embedder_spec.INIT_FILE) | 
					
						
						|  | elif embedder_type == EmbedderType.VERTEX_FEATURE: | 
					
						
						|  | embedder = VertexFeatureEmbedder( | 
					
						
						|  | num_vertices=embedder_spec.NUM_VERTICES, | 
					
						
						|  | feature_dim=embedder_spec.FEATURE_DIM, | 
					
						
						|  | embed_dim=embedder_dim, | 
					
						
						|  | train_features=embedder_spec.FEATURES_TRAINABLE, | 
					
						
						|  | ) | 
					
						
						|  | if embedder_spec.INIT_FILE != "": | 
					
						
						|  | embedder.load(embedder_spec.INIT_FILE) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError(f"Unexpected embedder type {embedder_type}") | 
					
						
						|  |  | 
					
						
						|  | if not embedder_spec.IS_TRAINABLE: | 
					
						
						|  | embedder.requires_grad_(False) | 
					
						
						|  |  | 
					
						
						|  | return embedder | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Embedder(nn.Module): | 
					
						
						|  | """ | 
					
						
						|  | Embedder module that serves as a container for embedders to use with different | 
					
						
						|  | meshes. Extends Module to automatically save / load state dict. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder." | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, cfg: CfgNode): | 
					
						
						|  | """ | 
					
						
						|  | Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule | 
					
						
						|  | "embedder_{i}". | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | cfg (CfgNode): configuration options | 
					
						
						|  | """ | 
					
						
						|  | super(Embedder, self).__init__() | 
					
						
						|  | self.mesh_names = set() | 
					
						
						|  | embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  | for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items(): | 
					
						
						|  | logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}") | 
					
						
						|  | self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim)) | 
					
						
						|  | self.mesh_names.add(mesh_name) | 
					
						
						|  | if cfg.MODEL.WEIGHTS != "": | 
					
						
						|  | self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS) | 
					
						
						|  |  | 
					
						
						|  | def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None): | 
					
						
						|  | if prefix is None: | 
					
						
						|  | prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX | 
					
						
						|  | state_dict = None | 
					
						
						|  | if fpath.endswith(".pkl"): | 
					
						
						|  | with PathManager.open(fpath, "rb") as hFile: | 
					
						
						|  | state_dict = pickle.load(hFile, encoding="latin1") | 
					
						
						|  | else: | 
					
						
						|  | with PathManager.open(fpath, "rb") as hFile: | 
					
						
						|  | state_dict = torch.load(hFile, map_location=torch.device("cpu")) | 
					
						
						|  | if state_dict is not None and "model" in state_dict: | 
					
						
						|  | state_dict_local = {} | 
					
						
						|  | for key in state_dict["model"]: | 
					
						
						|  | if key.startswith(prefix): | 
					
						
						|  | v_key = state_dict["model"][key] | 
					
						
						|  | if isinstance(v_key, np.ndarray): | 
					
						
						|  | v_key = torch.from_numpy(v_key) | 
					
						
						|  | state_dict_local[key[len(prefix) :]] = v_key | 
					
						
						|  |  | 
					
						
						|  | self.load_state_dict(state_dict_local, strict=False) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, mesh_name: str) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | Produce vertex embeddings for the specific mesh; vertex embeddings are | 
					
						
						|  | a tensor of shape [N, D] where: | 
					
						
						|  | N = number of vertices | 
					
						
						|  | D = number of dimensions in the embedding space | 
					
						
						|  | Args: | 
					
						
						|  | mesh_name (str): name of a mesh for which to obtain vertex embeddings | 
					
						
						|  | Return: | 
					
						
						|  | Vertex embeddings, a tensor of shape [N, D] | 
					
						
						|  | """ | 
					
						
						|  | return getattr(self, f"embedder_{mesh_name}")() | 
					
						
						|  |  | 
					
						
						|  | def has_embeddings(self, mesh_name: str) -> bool: | 
					
						
						|  | return hasattr(self, f"embedder_{mesh_name}") | 
					
						
						|  |  |