|  |  | 
					
						
						|  |  | 
					
						
						|  | import contextlib | 
					
						
						|  | from unittest import mock | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from detectron2.modeling import poolers | 
					
						
						|  | from detectron2.modeling.proposal_generator import rpn | 
					
						
						|  | from detectron2.modeling.roi_heads import keypoint_head, mask_head | 
					
						
						|  | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers | 
					
						
						|  |  | 
					
						
						|  | from .c10 import ( | 
					
						
						|  | Caffe2Compatible, | 
					
						
						|  | Caffe2FastRCNNOutputsInference, | 
					
						
						|  | Caffe2KeypointRCNNInference, | 
					
						
						|  | Caffe2MaskRCNNInference, | 
					
						
						|  | Caffe2ROIPooler, | 
					
						
						|  | Caffe2RPN, | 
					
						
						|  | caffe2_fast_rcnn_outputs_inference, | 
					
						
						|  | caffe2_keypoint_rcnn_inference, | 
					
						
						|  | caffe2_mask_rcnn_inference, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GenericMixin: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Caffe2CompatibleConverter: | 
					
						
						|  | """ | 
					
						
						|  | A GenericUpdater which implements the `create_from` interface, by modifying | 
					
						
						|  | module object and assign it with another class replaceCls. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, replaceCls): | 
					
						
						|  | self.replaceCls = replaceCls | 
					
						
						|  |  | 
					
						
						|  | def create_from(self, module): | 
					
						
						|  |  | 
					
						
						|  | assert isinstance(module, torch.nn.Module) | 
					
						
						|  | if issubclass(self.replaceCls, GenericMixin): | 
					
						
						|  |  | 
					
						
						|  | new_class = type( | 
					
						
						|  | "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__), | 
					
						
						|  | (self.replaceCls, module.__class__), | 
					
						
						|  | {}, | 
					
						
						|  | ) | 
					
						
						|  | module.__class__ = new_class | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | module.__class__ = self.replaceCls | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if isinstance(module, Caffe2Compatible): | 
					
						
						|  | module.tensor_mode = False | 
					
						
						|  |  | 
					
						
						|  | return module | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def patch(model, target, updater, *args, **kwargs): | 
					
						
						|  | """ | 
					
						
						|  | recursively (post-order) update all modules with the target type and its | 
					
						
						|  | subclasses, make a initialization/composition/inheritance/... via the | 
					
						
						|  | updater.create_from. | 
					
						
						|  | """ | 
					
						
						|  | for name, module in model.named_children(): | 
					
						
						|  | model._modules[name] = patch(module, target, updater, *args, **kwargs) | 
					
						
						|  | if isinstance(model, target): | 
					
						
						|  | return updater.create_from(model, *args, **kwargs) | 
					
						
						|  | return model | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def patch_generalized_rcnn(model): | 
					
						
						|  | ccc = Caffe2CompatibleConverter | 
					
						
						|  | model = patch(model, rpn.RPN, ccc(Caffe2RPN)) | 
					
						
						|  | model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler)) | 
					
						
						|  |  | 
					
						
						|  | return model | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @contextlib.contextmanager | 
					
						
						|  | def mock_fastrcnn_outputs_inference( | 
					
						
						|  | tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers | 
					
						
						|  | ): | 
					
						
						|  | with mock.patch.object( | 
					
						
						|  | box_predictor_type, | 
					
						
						|  | "inference", | 
					
						
						|  | autospec=True, | 
					
						
						|  | side_effect=Caffe2FastRCNNOutputsInference(tensor_mode), | 
					
						
						|  | ) as mocked_func: | 
					
						
						|  | yield | 
					
						
						|  | if check: | 
					
						
						|  | assert mocked_func.call_count > 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @contextlib.contextmanager | 
					
						
						|  | def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True): | 
					
						
						|  | with mock.patch( | 
					
						
						|  | "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference() | 
					
						
						|  | ) as mocked_func: | 
					
						
						|  | yield | 
					
						
						|  | if check: | 
					
						
						|  | assert mocked_func.call_count > 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @contextlib.contextmanager | 
					
						
						|  | def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True): | 
					
						
						|  | with mock.patch( | 
					
						
						|  | "{}.keypoint_rcnn_inference".format(patched_module), | 
					
						
						|  | side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint), | 
					
						
						|  | ) as mocked_func: | 
					
						
						|  | yield | 
					
						
						|  | if check: | 
					
						
						|  | assert mocked_func.call_count > 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ROIHeadsPatcher: | 
					
						
						|  | def __init__(self, heads, use_heatmap_max_keypoint): | 
					
						
						|  | self.heads = heads | 
					
						
						|  | self.use_heatmap_max_keypoint = use_heatmap_max_keypoint | 
					
						
						|  | self.previous_patched = {} | 
					
						
						|  |  | 
					
						
						|  | @contextlib.contextmanager | 
					
						
						|  | def mock_roi_heads(self, tensor_mode=True): | 
					
						
						|  | """ | 
					
						
						|  | Patching several inference functions inside ROIHeads and its subclasses | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | tensor_mode (bool): whether the inputs/outputs are caffe2's tensor | 
					
						
						|  | format or not. Default to True. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__ | 
					
						
						|  | mask_head_mod = mask_head.BaseMaskRCNNHead.__module__ | 
					
						
						|  |  | 
					
						
						|  | mock_ctx_managers = [ | 
					
						
						|  | mock_fastrcnn_outputs_inference( | 
					
						
						|  | tensor_mode=tensor_mode, | 
					
						
						|  | check=True, | 
					
						
						|  | box_predictor_type=type(self.heads.box_predictor), | 
					
						
						|  | ) | 
					
						
						|  | ] | 
					
						
						|  | if getattr(self.heads, "keypoint_on", False): | 
					
						
						|  | mock_ctx_managers += [ | 
					
						
						|  | mock_keypoint_rcnn_inference( | 
					
						
						|  | tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint | 
					
						
						|  | ) | 
					
						
						|  | ] | 
					
						
						|  | if getattr(self.heads, "mask_on", False): | 
					
						
						|  | mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)] | 
					
						
						|  |  | 
					
						
						|  | with contextlib.ExitStack() as stack: | 
					
						
						|  | for mgr in mock_ctx_managers: | 
					
						
						|  | stack.enter_context(mgr) | 
					
						
						|  | yield | 
					
						
						|  |  | 
					
						
						|  | def patch_roi_heads(self, tensor_mode=True): | 
					
						
						|  | self.previous_patched["box_predictor"] = self.heads.box_predictor.inference | 
					
						
						|  | self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference | 
					
						
						|  | self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference | 
					
						
						|  |  | 
					
						
						|  | def patched_fastrcnn_outputs_inference(predictions, proposal): | 
					
						
						|  | return caffe2_fast_rcnn_outputs_inference( | 
					
						
						|  | True, self.heads.box_predictor, predictions, proposal | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference | 
					
						
						|  |  | 
					
						
						|  | if getattr(self.heads, "keypoint_on", False): | 
					
						
						|  |  | 
					
						
						|  | def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances): | 
					
						
						|  | return caffe2_keypoint_rcnn_inference( | 
					
						
						|  | self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference | 
					
						
						|  |  | 
					
						
						|  | if getattr(self.heads, "mask_on", False): | 
					
						
						|  |  | 
					
						
						|  | def patched_mask_rcnn_inference(pred_mask_logits, pred_instances): | 
					
						
						|  | return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances) | 
					
						
						|  |  | 
					
						
						|  | mask_head.mask_rcnn_inference = patched_mask_rcnn_inference | 
					
						
						|  |  | 
					
						
						|  | def unpatch_roi_heads(self): | 
					
						
						|  | self.heads.box_predictor.inference = self.previous_patched["box_predictor"] | 
					
						
						|  | keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"] | 
					
						
						|  | mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"] | 
					
						
						|  |  |