import logging
import os
import pickle
from urllib.parse import parse_qs, urlparse
import torch
from fvcore.common.checkpoint import Checkpointer
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts


class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:
    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
    2. correctly load checkpoints that are only available on the main worker.
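
    Example (an illustrative sketch, not from the original documentation; ``build_model``,
    ``cfg`` and the checkpoint path below are placeholders)::

        model = build_model(cfg)  # any torch.nn.Module
        checkpointer = DetectionCheckpointer(model, save_dir="./output")
        # .pkl files from the Caffe2/Detectron2 model zoos and torch-serialized checkpoints
        # are both accepted; an optional "?matching_heuristics=True" query requests
        # name-matching conversion for torch checkpoints produced outside detectron2.
        checkpointer.load("/path/to/checkpoint.pth?matching_heuristics=True")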
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            # by default, only the main process writes checkpoints to disk
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager
        # set while load() runs, so that _load_file() can read the URL query
        self._parsed_url_during_load = None

    def load(self, path, *args, **kwargs):
        assert self._parsed_url_during_load is None
        need_sync = False
        logger = logging.getLogger(__name__)
        logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))

        if path and isinstance(self.model, DistributedDataParallel):
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # the main worker loads the checkpoint; its weights are broadcast below
                need_sync = True
            if not has_file:
                path = None  # don't load anything on workers that cannot read the file

        if path:
            parsed_url = urlparse(path)
            self._parsed_url_during_load = parsed_url
            path = parsed_url._replace(query="").geturl()  # remove the query from the filename
            path = self.path_manager.get_local_path(path)
        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            self.model._sync_params_and_buffers()
        self._parsed_url_during_load = None  # reset after loading
        return ret

    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # the file is in detectron2's model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume the file comes from the Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume the file is a pycls checkpoint
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}

        loaded = self._torch_load(filename)
        if "model" not in loaded:
            loaded = {"model": loaded}
        assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
        parsed_url = self._parsed_url_during_load
        queries = parse_qs(parsed_url.query)
        if queries.pop("matching_heuristics", "False") == ["True"]:
            loaded["matching_heuristics"] = True
        if len(queries) > 0:
            raise ValueError(
                f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
            )
        return loaded

    def _torch_load(self, f):
        # plain torch-serialized checkpoints are handled by the base Checkpointer
        return super()._load_file(f)

    def _load_model(self, checkpoint):
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # convert weights by name-matching heuristics
            checkpoint["model"] = align_and_update_state_dicts(
                self.model.state_dict(),
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
        # for non-Caffe2 models, load with the standard Checkpointer logic
        incompatible = super()._load_model(checkpoint)

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore the missing-key message for pixel_mean/pixel_std: they may be absent
            # in old checkpoints, but they are initialized from the config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        for k in incompatible.unexpected_keys[:]:
            # Ignore unexpected anchor_generator.cell_anchors keys: they exist in old
            # checkpoints but are now non-persistent buffers and are no longer saved.
            if "anchor_generator.cell_anchors" in k:
                incompatible.unexpected_keys.remove(k)
        return incompatible
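

# Usage sketch for distributed training (illustrative only, not part of this module;
# `build_model` refers to detectron2.modeling.build_model and `cfg` to a detectron2 config):
#
#   model = DistributedDataParallel(build_model(cfg), device_ids=[comm.get_local_rank()])
#   checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
#   checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)
#
# With a DistributedDataParallel model, load() only needs the checkpoint to be readable on
# the main worker: the loaded weights are broadcast to the other workers, although other
# checkpointables may not fully resume (see the warning emitted in load()).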