from typing import Tuple, Optional, Dict
import logging
import os
import shutil
from os import path
from PIL import Image

import torch
import torch.nn.functional as F
import numpy as np
import pycocotools.mask as mask_util
from threading import Thread
from queue import Queue
from dataclasses import dataclass
import copy

from tracker.utils.pano_utils import ID2RGBConverter
from tracker.utils.palette import davis_palette_np
from tracker.inference.object_manager import ObjectManager
from tracker.inference.object_info import ObjectInfo

log = logging.getLogger()

try:
    import hickle as hkl
except ImportError:
    log.warning('Failed to import hickle. Fine if not using multi-scale testing.')


class ResultSaver:
    def __init__(self,
                 output_root,
                 video_name,
                 *,
                 dataset,
                 object_manager: ObjectManager,
                 use_long_id,
                 palette=None,
                 save_mask=True,
                 save_scores=False,
                 score_output_root=None,
                 visualize_output_root=None,
                 visualize=False,
                 init_json=None):
        self.output_root = output_root
        self.video_name = video_name
        self.dataset = dataset.lower()
        self.use_long_id = use_long_id
        self.palette = palette
        self.object_manager = object_manager
        self.save_mask = save_mask
        self.save_scores = save_scores
        self.score_output_root = score_output_root
        self.visualize_output_root = visualize_output_root
        self.visualize = visualize

        if self.visualize:
            if self.palette is not None:
                self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3)
            else:
                self.colors = davis_palette_np

        self.need_remapping = True
        self.json_style = None
        self.id2rgb_converter = ID2RGBConverter()

        if 'burst' in self.dataset:
            assert init_json is not None
            self.input_segmentations = init_json['segmentations']
            self.segmentations = [{} for _ in init_json['segmentations']]
            self.annotated_frames = init_json['annotated_image_paths']
            self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'}
            self.video_json['segmentations'] = self.segmentations
            self.json_style = 'burst'

        # results are written asynchronously by a background worker thread
        self.queue = Queue(maxsize=10)
        self.thread = Thread(target=save_result, args=(self.queue, ))
        self.thread.daemon = True
        self.thread.start()

    def process(self,
                prob: torch.Tensor,
                frame_name: str,
                resize_needed: bool = False,
                shape: Optional[Tuple[int, int]] = None,
                last_frame: bool = False,
                path_to_image: Optional[str] = None):
        if resize_needed:
            prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear',
                                 align_corners=False)[:, 0]

        # Probability mask -> index mask
        mask = torch.argmax(prob, dim=0)
        if self.save_scores:
            # also need to pass prob to the worker
            prob = prob.cpu()
        else:
            prob = None

        # remap indices
        if self.need_remapping:
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id
            mask = new_mask

        # deep-copy the id mappings so the worker thread sees a consistent snapshot
        args = ResultArgs(saver=self,
                          prob=prob,
                          mask=mask.cpu(),
                          frame_name=frame_name,
                          path_to_image=path_to_image,
                          tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj),
                          obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id),
                          last_frame=last_frame)

        self.queue.put(args)

    def end(self):
        self.queue.put(None)
        self.queue.join()
        self.thread.join()


@dataclass
class ResultArgs:
    saver: ResultSaver
    prob: torch.Tensor
    mask: torch.Tensor
    frame_name: str
    path_to_image: str
    tmp_id_to_obj: Dict[int, ObjectInfo]
    obj_to_tmp_id: Dict[ObjectInfo, int]
    last_frame: bool


def save_result(queue: Queue):
    while True:
        args: ResultArgs = queue.get()
        if args is None:
            queue.task_done()
            break

        saver = args.saver
        prob = args.prob
        mask = args.mask
        frame_name = args.frame_name
        path_to_image = args.path_to_image
        tmp_id_to_obj = args.tmp_id_to_obj
        obj_to_tmp_id = args.obj_to_tmp_id
        last_frame = args.last_frame
        all_obj_ids = [k.id for k in obj_to_tmp_id]

        # record output in the json file
        if saver.json_style == 'burst':
            if frame_name in saver.annotated_frames:
                frame_index = saver.annotated_frames.index(frame_name)
                input_segments = saver.input_segmentations[frame_index]
                frame_segments = saver.segmentations[frame_index]

                for id in all_obj_ids:
                    if id in input_segments:
                        # if this frame has been given as input, just copy
                        frame_segments[id] = input_segments[id]
                        continue

                    segment = {}
                    segment_mask = (mask == id)
                    if segment_mask.sum() > 0:
                        # pycocotools expects a Fortran-ordered uint8 array
                        coco_mask = mask_util.encode(
                            np.asfortranarray(segment_mask.numpy().astype(np.uint8)))
                        segment['rle'] = coco_mask['counts'].decode('utf-8')
                        frame_segments[id] = segment

        # save the mask to disk
        if saver.save_mask:
            if saver.use_long_id:
                out_mask = mask.numpy().astype(np.uint32)
                rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8)
                for id in all_obj_ids:
                    _, image = saver.id2rgb_converter.convert(id)
                    obj_mask = (out_mask == id)
                    rgb_mask[obj_mask] = image
                out_img = Image.fromarray(rgb_mask)
            else:
                rgb_mask = None
                out_mask = mask.numpy().astype(np.uint8)
                out_img = Image.fromarray(out_mask)
                if saver.palette is not None:
                    out_img.putpalette(saver.palette)

            this_out_path = path.join(saver.output_root, saver.video_name)
            os.makedirs(this_out_path, exist_ok=True)
            out_img.save(os.path.join(this_out_path, frame_name[:-4] + '.png'))

        # save scores for multi-scale testing
        if saver.save_scores:
            this_out_path = path.join(saver.score_output_root, saver.video_name)
            os.makedirs(this_out_path, exist_ok=True)
            prob = (prob.detach().numpy() * 255).astype(np.uint8)
            if last_frame:
                tmp_to_obj_mapping = {obj.id: tmp_id for obj, tmp_id in obj_to_tmp_id.items()}
                hkl.dump(tmp_to_obj_mapping, path.join(this_out_path, 'backward.hkl'), mode='w')

            hkl.dump(prob,
                     path.join(this_out_path, f'{frame_name[:-4]}.hkl'),
                     mode='w',
                     compression='lzf')

        if saver.visualize:
            if path_to_image is not None:
                image_np = np.array(Image.open(path_to_image))
            else:
                raise ValueError('Cannot visualize without path_to_image')

            if rgb_mask is None:
                # we need to apply a palette
                rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8)
                for id in all_obj_ids:
                    image = saver.colors[id]
                    obj_mask = (out_mask == id)
                    rgb_mask[obj_mask] = image

            alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5
            alpha = alpha[:, :, None]
            blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8)

            # find a place to save the visualization
            this_vis_path = path.join(saver.visualize_output_root, saver.video_name)
            os.makedirs(this_vis_path, exist_ok=True)
            Image.fromarray(blend).save(path.join(this_vis_path, frame_name[:-4] + '.jpg'))

        queue.task_done()
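

# Illustrative sketch, not part of the original pipeline: how a BURST-style RLE
# string written by save_result() above could be decoded back into a binary mask.
# The helper name and signature are assumptions for demonstration only; the frame's
# (height, width) must be supplied by the caller, since only the RLE counts are stored.
def decode_burst_rle(rle_str: str, height: int, width: int) -> np.ndarray:
    rle = {'size': [height, width], 'counts': rle_str.encode('utf-8')}
    # pycocotools returns a (height, width) uint8 array with values in {0, 1}
    return mask_util.decode(rle)
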

def make_zip(dataset, run_dir, exp_id, mask_output_root):
    if dataset.startswith('y'):
        # YouTubeVOS
        log.info('Making zip for YouTubeVOS...')
        shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir,
                            'Annotations')
    elif dataset == 'd17-test-dev':
        # DAVIS 2017 test-dev -- zip from within the Annotations folder
        log.info('Making zip for DAVIS test-dev...')
        shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root)
    elif dataset == 'mose-val':
        # MOSE validation -- same as DAVIS test-dev
        log.info('Making zip for MOSE validation...')
        shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root)
    elif dataset == 'lvos-test':
        # LVOS test -- same as YouTubeVOS
        log.info('Making zip for LVOS test...')
        shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir,
                            'Annotations')
    else:
        log.info(f'Not making zip for {dataset}.')
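

# Typical driver loop (sketch only): every name below other than ResultSaver,
# make_zip, and ObjectManager is a placeholder, and the ObjectManager setup and
# inference loop depend on the rest of the tracker, so this is illustrative
# rather than runnable as-is.
#
#     object_manager = ObjectManager()                    # assumed no-arg constructor
#     saver = ResultSaver('output/Annotations', 'video0',
#                         dataset='d17-test-dev',
#                         object_manager=object_manager,
#                         use_long_id=False,
#                         palette=None)
#     for ti, (frame_name, prob) in enumerate(inference_frames):  # hypothetical iterable
#         saver.process(prob, frame_name,
#                       resize_needed=False,
#                       last_frame=(ti == num_frames - 1))
#     saver.end()   # flush the queue and join the background writer thread
#     make_zip('d17-test-dev', 'output', 'exp0', 'output/Annotations')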