Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
| import torch | |
| from torch.nn import functional as F | |
| import numpy as np | |
# pydensecrf is an optional dependency: when it cannot be loaded, `dcrf` is
# set to None and dense_crf_post_process() reports the problem at call time.
try:
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import (
        unary_from_softmax,
        unary_from_labels,
        create_pairwise_bilateral,
        create_pairwise_gaussian,
    )
except Exception:
    # Deliberately broader than ImportError: a compiled extension may fail to
    # load with other error types. Unlike a bare `except:`, this does not
    # swallow KeyboardInterrupt or SystemExit.
    dcrf = None
def dense_crf_post_process(
    logits,
    image,
    n_labels=None,
    max_iters=5,
    pos_xy_std=(3, 3),
    pos_w=3,
    bi_xy_std=(80, 80),
    bi_rgb_std=(13, 13, 13),
    bi_w=10,
):
    """Refine a segmentation prediction with a dense CRF (pydensecrf).

    Args:
        logits: The prediction to refine. Three input forms are handled:
            * ``torch.Tensor``: softmax is applied over dim 0, so class
              scores laid out as [C, H, W] are expected.
            * 3-D ``numpy.ndarray``: passed to ``unary_from_softmax``
              unchanged (no softmax is applied here), so it is presumably
              already a probability map of shape [C, H, W].
            * anything else: treated as a label map and passed to
              ``unary_from_labels``; ``n_labels`` is then required.
        image: RGB image laid out as [H, W, 3]. ``shape[0]`` is used as the
            height and ``shape[1]`` as the width. It is made C-contiguous
            before being handed to pydensecrf; its dtype is not converted
            (NOTE(review): pydensecrf expects uint8 -- confirm at callers).
        n_labels: Number of classes. Only read for label-map input; for the
            other two input forms it is overwritten with ``logits.shape[0]``.
        max_iters: Number of CRF inference steps to run.
        pos_xy_std: ``sxy`` of the position-only (Gaussian) pairwise term.
        pos_w: ``compat`` weight of the position-only pairwise term.
        bi_xy_std: ``sxy`` of the bilateral (position + color) pairwise term.
        bi_rgb_std: ``srgb`` of the bilateral pairwise term.
        bi_w: ``compat`` weight of the bilateral pairwise term.

    Returns:
        ``torch.Tensor`` of shape [n_labels, H, W] holding the output of the
        CRF inference (the values returned by ``DenseCRF2D.inference``, not
        raw logits).

    Raises:
        FileNotFoundError: If pydensecrf could not be imported.
        ValueError: If ``logits`` is a label map and ``n_labels`` is None.
    """
    if dcrf is None:
        raise FileNotFoundError(
            "pydensecrf is required to perform dense crf inference."
        )

    is_tensor = isinstance(logits, torch.Tensor)
    if is_tensor:
        logits = F.softmax(logits, dim=0).detach().cpu().numpy()

    if is_tensor or logits.ndim == 3:
        # Per-class probability map: one unary row per class.
        U = unary_from_softmax(logits)
        n_labels = logits.shape[0]
    else:
        # Hard label map: the class count cannot be read off the array shape.
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if n_labels is None:
            raise ValueError(
                "n_labels must be provided when logits is a label map."
            )
        U = unary_from_labels(logits, n_labels, zero_unsure=False)

    # pydensecrf rejects non-contiguous buffers (e.g. a transposed CHW view);
    # this is a no-op for arrays that are already C-contiguous.
    image = np.ascontiguousarray(image)
    height, width = image.shape[0], image.shape[1]

    # DenseCRF2D takes (width, height, n_labels), in that order.
    d = dcrf.DenseCRF2D(width, height, n_labels)
    d.setUnaryEnergy(U)

    # This adds the color-independent term, features are the locations only.
    d.addPairwiseGaussian(
        sxy=pos_xy_std,
        compat=pos_w,
        kernel=dcrf.DIAG_KERNEL,
        normalization=dcrf.NORMALIZE_SYMMETRIC,
    )

    # This adds the color-dependent term, i.e. features are (x,y,r,g,b).
    d.addPairwiseBilateral(
        sxy=bi_xy_std,
        srgb=bi_rgb_std,
        rgbim=image,
        compat=bi_w,
        kernel=dcrf.DIAG_KERNEL,
        normalization=dcrf.NORMALIZE_SYMMETRIC,
    )

    # Run `max_iters` inference steps and restore the [C, H, W] layout.
    refined = d.inference(max_iters)
    refined = np.asarray(refined).reshape((n_labels, height, width))
    return torch.from_numpy(refined)