# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import wandb

import utils
from utils.datasets_semseg import (ade_classes, hypersim_classes,
                                   nyu_v2_40_classes)


def inv_norm(tensor: torch.Tensor) -> torch.Tensor:
    """Inverse of the (ImageNet mean/std) normalization applied during pre-processing."""
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
    return inv_normalize(tensor)
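
# Usage sketch (the tensor name below is a hypothetical example, not part of
# this module): undo the normalization before visualizing an RGB tensor.
#
#   rgb = inv_norm(normalized_rgb)      # (3, H, W), roughly back in [0, 1]
#   wandb.Image(rgb.clamp(0, 1))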


def log_semseg_wandb(
    images: torch.Tensor,
    preds: List[np.ndarray],
    gts: List[np.ndarray],
    depth_gts: List[np.ndarray],
    dataset_name: str = 'ade20k',
    image_count: int = 8,
    prefix: str = "",
):
    """Logs up to `image_count` semantic segmentation predictions and ground
    truths to Weights & Biases as mask overlays on the de-normalized images.
    """
    if dataset_name == 'ade20k':
        classes = ade_classes()
    elif dataset_name == 'hypersim':
        classes = hypersim_classes()
    elif dataset_name == 'nyu':
        classes = nyu_v2_40_classes()
    else:
        raise ValueError(f'Dataset {dataset_name} not supported for logging to wandb.')

    class_labels = {i: cls for i, cls in enumerate(classes)}
    class_labels[len(classes)] = "void"
    class_labels[utils.SEG_IGNORE_INDEX] = "ignore"

    image_count = min(len(images), image_count)
    images = images[:image_count]
    preds = preds[:image_count]
    gts = gts[:image_count]
    depth_gts = depth_gts[:image_count] if len(depth_gts) > 0 else None

    semseg_images = {}

    for i, (image, pred, gt) in enumerate(zip(images, preds, gts)):
        image = inv_norm(image)
        # Mask out predictions wherever the ground truth is ignored
        pred[gt == utils.SEG_IGNORE_INDEX] = utils.SEG_IGNORE_INDEX

        semseg_image = wandb.Image(image, masks={
            "predictions": {
                "mask_data": pred,
                "class_labels": class_labels,
            },
            "ground_truth": {
                "mask_data": gt,
                "class_labels": class_labels,
            }
        })

        semseg_images[f"{prefix}_{i}"] = semseg_image

        if depth_gts is not None:
            semseg_images[f"{prefix}_{i}_depth"] = wandb.Image(depth_gts[i])

    wandb.log(semseg_images, commit=False)
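
# Usage sketch (hypothetical call site; assumes wandb.init() has been called,
# `images` is a (B, 3, H, W) tensor of normalized RGB, and `preds`/`gts` are
# lists of (H, W) integer label maps):
#
#   log_semseg_wandb(images, preds, gts, depth_gts=[],
#                    dataset_name='ade20k', prefix='val/semseg')
#
# Because wandb.log is called with commit=False, the images are attached to
# the next committed step (e.g. the following metrics log).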


def log_taskonomy_wandb(
    preds: Dict[str, torch.Tensor],
    gts: Dict[str, torch.Tensor],
    image_count: int = 8,
    prefix: str = "",
):
    """Logs up to `image_count` per-task predictions and ground truths
    (e.g. rgb, depth, normals) to Weights & Biases.
    """
    pred_tasks = list(preds.keys())
    gt_tasks = list(gts.keys())
    if 'mask_valid' in gt_tasks:
        gt_tasks.remove('mask_valid')

    image_count = min(len(preds[pred_tasks[0]]), image_count)

    all_images = {}

    for i in range(image_count):
        # Log ground truths
        for task in gt_tasks:
            gt_img = gts[task][i]
            if task == 'rgb':
                gt_img = inv_norm(gt_img)
            if gt_img.shape[0] == 1:
                gt_img = gt_img[0]
            elif gt_img.shape[0] == 2:
                # Pad 2-channel outputs to 3 channels so wandb can render them
                gt_img = F.pad(gt_img, (0, 0, 0, 0, 0, 1), mode='constant', value=0.0)

            gt_img = wandb.Image(gt_img, caption=f'GT #{i}')
            key = f'{prefix}_gt_{task}'
            if key not in all_images:
                all_images[key] = [gt_img]
            else:
                all_images[key].append(gt_img)

        # Log predictions
        for task in pred_tasks:
            pred_img = preds[task][i]
            if task == 'rgb':
                pred_img = inv_norm(pred_img)
            if pred_img.shape[0] == 1:
                pred_img = pred_img[0]
            elif pred_img.shape[0] == 2:
                pred_img = F.pad(pred_img, (0, 0, 0, 0, 0, 1), mode='constant', value=0.0)

            pred_img = wandb.Image(pred_img, caption=f'Pred #{i}')
            key = f'{prefix}_pred_{task}'
            if key not in all_images:
                all_images[key] = [pred_img]
            else:
                all_images[key].append(pred_img)

    wandb.log(all_images, commit=False)
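
# Usage sketch (hypothetical dict contents; assumes wandb.init() has been
# called and each tensor is batched as (B, C, H, W) with C in {1, 2, 3}):
#
#   preds = {'rgb': pred_rgb, 'depth': pred_depth, 'normal': pred_normal}
#   gts = {'rgb': gt_rgb, 'depth': gt_depth, 'normal': gt_normal,
#          'mask_valid': mask_valid}
#   log_taskonomy_wandb(preds, gts, prefix='val')
#
# As above, commit=False defers the actual wandb commit to the next log call.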