import torch
import torch.nn as nn
from promptda.model.dpt import DPTHead
from promptda.model.config import model_configs
from promptda.utils.logger import Log
import os
from pathlib import Path
from huggingface_hub import hf_hub_download


class PromptDA(nn.Module):
    patch_size = 14  # patch size of the pretrained dinov2 model
    use_bn = False
    use_clstoken = False
    output_act = 'sigmoid'

    def __init__(self,
                 encoder='vitl',
                 ckpt_path='data/checkpoints/promptda_vitl.ckpt'):
        super().__init__()
        model_config = model_configs[encoder]
        self.encoder = encoder
        self.model_config = model_config
        # DINOv2 backbone, loaded from the repo's vendored torch.hub directory
        self.pretrained = torch.hub.load(
            'torchhub/facebookresearch_dinov2_main',
            'dinov2_{:}14'.format(encoder),
            source='local',
            pretrained=False)
        dim = self.pretrained.blocks[0].attn.qkv.in_features
        self.depth_head = DPTHead(nclass=1,
                                  in_channels=dim,
                                  features=model_config['features'],
                                  out_channels=model_config['out_channels'],
                                  use_bn=self.use_bn,
                                  use_clstoken=self.use_clstoken,
                                  output_act=self.output_act)
        # mean and std of the pretrained dinov2 model
        self.register_buffer('_mean', torch.tensor(
            [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('_std', torch.tensor(
            [0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.load_checkpoint(ckpt_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, model_kwargs=None, **hf_kwargs):
        """
        Load a model from a checkpoint file or the Hugging Face Hub.
        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
        ### Returns:
        - A new instance of `PromptDA` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            ckpt_path = pretrained_model_name_or_path
        else:
            # not a local path: treat it as a Hub repo id and download the checkpoint
            ckpt_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.ckpt",
                **hf_kwargs
            )
        if model_kwargs is None:
            model_kwargs = {}
        model_kwargs.update({'ckpt_path': ckpt_path})
        return cls(**model_kwargs)
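
    # Usage sketch (illustrative, not part of the original module). The Hub
    # repo id below is hypothetical -- any repo hosting a `model.ckpt` works;
    # a local checkpoint path is used as-is:
    #   model = PromptDA.from_pretrained('data/checkpoints/promptda_vitl.ckpt')
    #   model = PromptDA.from_pretrained('some-org/promptda_vitl')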

    def load_checkpoint(self, ckpt_path):
        if os.path.exists(ckpt_path):
            Log.info(f'Loading checkpoint from {ckpt_path}')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            # checkpoint keys carry a fixed 9-character training-time prefix;
            # strip it so the names match this module's state_dict
            self.load_state_dict(
                {k[9:]: v for k, v in checkpoint['state_dict'].items()})
        else:
            Log.warn(f'Checkpoint {ckpt_path} not found')

    def forward(self, x, prompt_depth=None):
        assert prompt_depth is not None, 'prompt_depth is required'
        # scale the prompt depth to [0, 1]; the sigmoid head predicts in this range
        prompt_depth, min_val, max_val = self.normalize(prompt_depth)
        h, w = x.shape[-2:]
        # multi-scale DINOv2 features (with class tokens) for the DPT head
        features = self.pretrained.get_intermediate_layers(
            x, self.model_config['layer_idxs'],
            return_class_token=True)
        patch_h, patch_w = h // self.patch_size, w // self.patch_size
        depth = self.depth_head(features, patch_h, patch_w, prompt_depth)
        # map the prediction back to the metric range of the prompt
        depth = self.denormalize(depth, min_val, max_val)
        return depth

    @torch.no_grad()
    def predict(self,
                image: torch.Tensor,
                prompt_depth: torch.Tensor) -> torch.Tensor:
        # convenience inference wrapper around forward (no gradients needed)
        return self.forward(image, prompt_depth)

    def normalize(self,
                  prompt_depth: torch.Tensor):
        # min-max normalization: quantiles at q=0 and q=1 are the per-sample
        # minimum and maximum of the prompt depth
        B, C, H, W = prompt_depth.shape
        min_val = torch.quantile(
            prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
        max_val = torch.quantile(
            prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]
        prompt_depth = (prompt_depth - min_val) / (max_val - min_val)
        return prompt_depth, min_val, max_val
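
    # Worked example: prompt_depth = [[2.0, 4.0, 3.0]] gives min_val = 2.0 and
    # max_val = 4.0, so the normalized prompt is [[0.0, 1.0, 0.5]]; denormalize
    # below inverts this exactly: 0.5 * (4.0 - 2.0) + 2.0 = 3.0.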

    def denormalize(self,
                    depth: torch.Tensor,
                    min_val: torch.Tensor,
                    max_val: torch.Tensor):
        return depth * (max_val - min_val) + min_val
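

# Minimal smoke test: a sketch under stated assumptions, not part of the
# original module. It assumes the vendored DINOv2 hub directory and the DPT
# head are importable, uses random tensors in place of a real image and LiDAR
# prompt, and only checks shapes (with no checkpoint, weights stay random).
# H and W must be multiples of patch_size (14).
if __name__ == '__main__':
    model = PromptDA(encoder='vitl').eval()
    image = torch.rand(1, 3, 14 * 54, 14 * 72)  # (B, 3, H, W), RGB in [0, 1]
    prompt_depth = torch.rand(1, 1, 192, 256)   # (B, 1, h, w), metric depth
    depth = model.predict(image, prompt_depth)
    print(depth.shape)  # expected: (1, 1, H, W)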