import logging
import os
import pickle
from urllib.parse import parse_qs, urlparse

import torch
from fvcore.common.checkpoint import Checkpointer
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts


class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:
    1. handle models in the Detectron and Detectron2 model zoos, and apply conversions
       for legacy models.
    2. correctly load checkpoints that are only available on the main worker.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            # by default, only the main process writes checkpoints to disk
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager
        self._parsed_url_during_load = None

    def load(self, path, *args, **kwargs):
        assert self._parsed_url_during_load is None
        need_sync = False
        logger = logging.getLogger(__name__)
        logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))

        if path and isinstance(self.model, DistributedDataParallel):
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # some workers cannot read the file: load on the workers that
                # can, then broadcast the states from the main worker afterwards
                need_sync = True
            if not has_file:
                path = None  # don't load on workers that cannot read the file

        if path:
            parsed_url = urlparse(path)
            self._parsed_url_during_load = parsed_url
            path = parsed_url._replace(query="").geturl()  # remove query from filename
            path = self.path_manager.get_local_path(path)
        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            self.model._sync_params_and_buffers()
        self._parsed_url_during_load = None
        return ret
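
    # After a partial load under DistributedDataParallel, the private
    # ``_sync_params_and_buffers()`` call above broadcasts rank 0's states to
    # all replicas. A rough equivalent using the public API (a sketch,
    # assuming torch.distributed is initialized) would be:
    #
    #   for tensor in model.module.state_dict().values():
    #       torch.distributed.broadcast(tensor, src=0)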

    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # the file is in Detectron2's model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume the file is from the Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume the file is a pycls checkpoint, which uses the ".pyth" extension
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}

        loaded = self._torch_load(filename)
        if "model" not in loaded:
            loaded = {"model": loaded}
        assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
        parsed_url = self._parsed_url_during_load
        queries = parse_qs(parsed_url.query)
        if queries.pop("matching_heuristics", "False") == ["True"]:
            loaded["matching_heuristics"] = True
        if len(queries) > 0:
            raise ValueError(
                f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
            )
        return loaded
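
    # The query string parsed above lets callers opt into heuristic name
    # matching for plain torch checkpoints, e.g. (the path is a placeholder):
    #
    #   checkpointer.load("/path/to/weights.pth?matching_heuristics=True")
    #
    # Any query key other than "matching_heuristics" raises a ValueError.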

    def _torch_load(self, f):
        # a separate hook for reading raw torch checkpoint files, so that
        # subclasses can override how they are loaded
        return super()._load_file(f)

    def _load_model(self, checkpoint):
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # convert weights by name-matching heuristics
            checkpoint["model"] = align_and_update_state_dicts(
                self.model.state_dict(),
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
        # for non-caffe2 models, use the standard loading path
        incompatible = super()._load_model(checkpoint)

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # ignore the missing-key message for pixel_mean/std: they may be
            # missing in old checkpoints, but are initialized from config anyway
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        for k in incompatible.unexpected_keys[:]:
            # ignore unexpected keys for cell anchors: they exist in old
            # checkpoints but are now non-persistent buffers, so they will
            # not appear in new checkpoints
            if "anchor_generator.cell_anchors" in k:
                incompatible.unexpected_keys.remove(k)
        return incompatible
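
# A typical training-time pattern (a sketch; ``model``, ``optimizer``, and
# ``cfg`` are placeholders supplied by the caller):
#
#   checkpointer = DetectionCheckpointer(model, cfg.OUTPUT_DIR, optimizer=optimizer)
#   checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)
#   start_iter = checkpoint.get("iteration", -1) + 1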