|  |  | 
					
						
						|  | import collections | 
					
						
						|  | from dataclasses import dataclass | 
					
						
						|  | from typing import Callable, List, Optional, Tuple | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  | from detectron2.structures import Boxes, Instances, ROIMasks | 
					
						
						|  | from detectron2.utils.registry import _convert_target_to_string, locate | 
					
						
						|  |  | 
					
						
						|  | from .torchscript_patch import patch_builtin_len | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class Schema: | 
					
						
						|  | """ | 
					
						
						|  | A Schema defines how to flatten a possibly hierarchical object into tuple of | 
					
						
						|  | primitive objects, so it can be used as inputs/outputs of PyTorch's tracing. | 
					
						
						|  |  | 
					
						
						|  | PyTorch does not support tracing a function that produces rich output | 
					
						
						|  | structures (e.g. dict, Instances, Boxes). To trace such a function, we | 
					
						
						|  | flatten the rich object into tuple of tensors, and return this tuple of tensors | 
					
						
						|  | instead. Meanwhile, we also need to know how to "rebuild" the original object | 
					
						
						|  | from the flattened results, so we can evaluate the flattened results. | 
					
						
						|  | A Schema defines how to flatten an object, and while flattening it, it records | 
					
						
						|  | necessary schemas so that the object can be rebuilt using the flattened outputs. | 
					
						
						|  |  | 
					
						
						|  | The flattened object and the schema object is returned by ``.flatten`` classmethod. | 
					
						
						|  | Then the original object can be rebuilt with the ``__call__`` method of schema. | 
					
						
						|  |  | 
					
						
						|  | A Schema is a dataclass that can be serialized easily. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _concat(values): | 
					
						
						|  | ret = () | 
					
						
						|  | sizes = [] | 
					
						
						|  | for v in values: | 
					
						
						|  | assert isinstance(v, tuple), "Flattened results must be a tuple" | 
					
						
						|  | ret = ret + v | 
					
						
						|  | sizes.append(len(v)) | 
					
						
						|  | return ret, sizes | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _split(values, sizes): | 
					
						
						|  | if len(sizes): | 
					
						
						|  | expected_len = sum(sizes) | 
					
						
						|  | assert ( | 
					
						
						|  | len(values) == expected_len | 
					
						
						|  | ), f"Values has length {len(values)} but expect length {expected_len}." | 
					
						
						|  | ret = [] | 
					
						
						|  | for k in range(len(sizes)): | 
					
						
						|  | begin, end = sum(sizes[:k]), sum(sizes[: k + 1]) | 
					
						
						|  | ret.append(values[begin:end]) | 
					
						
						|  | return ret | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class ListSchema(Schema): | 
					
						
						|  | schemas: List[Schema] | 
					
						
						|  | sizes: List[int] | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | values = self._split(values, self.sizes) | 
					
						
						|  | if len(values) != len(self.schemas): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!" | 
					
						
						|  | ) | 
					
						
						|  | values = [m(v) for m, v in zip(self.schemas, values)] | 
					
						
						|  | return list(values) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | res = [flatten_to_tuple(k) for k in obj] | 
					
						
						|  | values, sizes = cls._concat([k[0] for k in res]) | 
					
						
						|  | return values, cls([k[1] for k in res], sizes) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class TupleSchema(ListSchema): | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | return tuple(super().__call__(values)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class IdentitySchema(Schema): | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | return values[0] | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | return (obj,), cls() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class DictSchema(ListSchema): | 
					
						
						|  | keys: List[str] | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | values = super().__call__(values) | 
					
						
						|  | return dict(zip(self.keys, values)) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | for k in obj.keys(): | 
					
						
						|  | if not isinstance(k, str): | 
					
						
						|  | raise KeyError("Only support flattening dictionaries if keys are str.") | 
					
						
						|  | keys = sorted(obj.keys()) | 
					
						
						|  | values = [obj[k] for k in keys] | 
					
						
						|  | ret, schema = ListSchema.flatten(values) | 
					
						
						|  | return ret, cls(schema.schemas, schema.sizes, keys) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class InstancesSchema(DictSchema): | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | image_size, fields = values[-1], values[:-1] | 
					
						
						|  | fields = super().__call__(fields) | 
					
						
						|  | return Instances(image_size, **fields) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | ret, schema = super().flatten(obj.get_fields()) | 
					
						
						|  | size = obj.image_size | 
					
						
						|  | if not isinstance(size, torch.Tensor): | 
					
						
						|  | size = torch.tensor(size) | 
					
						
						|  | return ret + (size,), schema | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class TensorWrapSchema(Schema): | 
					
						
						|  | """ | 
					
						
						|  | For classes that are simple wrapper of tensors, e.g. | 
					
						
						|  | Boxes, RotatedBoxes, BitMasks | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | class_name: str | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, values): | 
					
						
						|  | return locate(self.class_name)(values[0]) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def flatten(cls, obj): | 
					
						
						|  | return (obj.tensor,), cls(_convert_target_to_string(type(obj))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def flatten_to_tuple(obj): | 
					
						
						|  | """ | 
					
						
						|  | Flatten an object so it can be used for PyTorch tracing. | 
					
						
						|  | Also returns how to rebuild the original object from the flattened outputs. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | res (tuple): the flattened results that can be used as tracing outputs | 
					
						
						|  | schema: an object with a ``__call__`` method such that ``schema(res) == obj``. | 
					
						
						|  | It is a pure dataclass that can be serialized. | 
					
						
						|  | """ | 
					
						
						|  | schemas = [ | 
					
						
						|  | ((str, bytes), IdentitySchema), | 
					
						
						|  | (list, ListSchema), | 
					
						
						|  | (tuple, TupleSchema), | 
					
						
						|  | (collections.abc.Mapping, DictSchema), | 
					
						
						|  | (Instances, InstancesSchema), | 
					
						
						|  | ((Boxes, ROIMasks), TensorWrapSchema), | 
					
						
						|  | ] | 
					
						
						|  | for klass, schema in schemas: | 
					
						
						|  | if isinstance(obj, klass): | 
					
						
						|  | F = schema | 
					
						
						|  | break | 
					
						
						|  | else: | 
					
						
						|  | F = IdentitySchema | 
					
						
						|  |  | 
					
						
						|  | return F.flatten(obj) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class TracingAdapter(nn.Module):
    """
    A model may take rich input/output format (e.g. dict or custom classes),
    but `torch.jit.trace` requires tuple of tensors as input/output.
    This adapter flattens input/output format of a model so it becomes traceable.

    It also records the necessary schema to rebuild model's inputs/outputs from flattened
    inputs/outputs.

    Example:
    ::
        outputs = model(inputs)   # inputs/outputs may be rich structure
        adapter = TracingAdapter(model, inputs)

        # can now trace the model, with adapter.flattened_inputs, or another
        # tuple of tensors with the same length and meaning
        traced = torch.jit.trace(adapter, adapter.flattened_inputs)

        # traced model can only produce flattened outputs (tuple of tensors)
        flattened_outputs = traced(*adapter.flattened_inputs)
        # adapter knows the schema to convert it back (new_outputs == outputs)
        new_outputs = adapter.outputs_schema(flattened_outputs)
    """

    flattened_inputs: Tuple[torch.Tensor] = None
    """
    Flattened version of inputs given to this class's constructor.
    """

    inputs_schema: Schema = None
    """
    Schema of the inputs given to this class's constructor.
    """

    outputs_schema: Schema = None
    """
    Schema of the output produced by calling the given model with inputs.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs,
        inference_func: Optional[Callable] = None,
        allow_non_tensor: bool = False,
    ):
        """
        Args:
            model: an nn.Module
            inputs: An input argument or a tuple of input arguments used to call model.
                After flattening, it has to only consist of tensors.
            inference_func: a callable that takes (model, *inputs), calls the
                model with inputs, and return outputs. By default it
                is ``lambda model, *inputs: model(*inputs)``. Can be override
                if you need to call the model differently.
            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
                This option will filter out non-tensor objects to make the
                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
                This is useful when you're only interested in the single trace of
                execution (e.g. for flop count), but not interested in
                generalizing the traced graph to new inputs.
        """
        super().__init__()
        # Unwrap (Distributed)DataParallel so tracing operates on the
        # underlying module directly.
        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
            model = model.module
        self.model = model
        # Normalize a single input argument into a 1-tuple of arguments.
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        self.inputs = inputs
        self.allow_non_tensor = allow_non_tensor

        if inference_func is None:
            inference_func = lambda model, *inputs: model(*inputs)
        self.inference_func = inference_func

        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)

        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
            return
        if self.allow_non_tensor:
            # Drop non-tensor inputs. The schema is discarded because the
            # original inputs can no longer be rebuilt from tensors alone.
            self.flattened_inputs = tuple(
                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
            )
            self.inputs_schema = None
        else:
            for input in self.flattened_inputs:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Inputs for tracing must only contain tensors. "
                        f"Got a {type(input)} instead."
                    )

    def forward(self, *args: torch.Tensor):
        # patch_builtin_len makes ``len()`` traceable inside the model;
        # no_grad because tracing is inference-only.
        with torch.no_grad(), patch_builtin_len():
            if self.inputs_schema is not None:
                # Rebuild the rich input structure from the flat tensor args.
                inputs_orig_format = self.inputs_schema(args)
            else:
                # No schema (allow_non_tensor path): only the exact original
                # inputs can be used, so verify identity of every argument.
                if len(args) != len(self.flattened_inputs) or any(
                    x is not y for x, y in zip(args, self.flattened_inputs)
                ):
                    raise ValueError(
                        "TracingAdapter does not contain valid inputs_schema."
                        " So it cannot generalize to other inputs and must be"
                        " traced with `.flattened_inputs`."
                    )
                inputs_orig_format = self.inputs

            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

            flattened_output_tensors = tuple(
                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
            )
            if len(flattened_output_tensors) < len(flattened_outputs):
                # Some outputs did not flatten to tensors.
                if self.allow_non_tensor:
                    flattened_outputs = flattened_output_tensors
                    self.outputs_schema = None
                else:
                    raise ValueError(
                        "Model cannot be traced because some model outputs "
                        "cannot flatten to tensors."
                    )
            else:
                # All outputs are tensors: record the schema on the first call
                # so callers can rebuild rich outputs from the traced model.
                if self.outputs_schema is None:
                    self.outputs_schema = schema
                else:
                    assert self.outputs_schema == schema, (
                        "Model should always return outputs with the same "
                        "structure so it can be traced!"
                    )
            return flattened_outputs

    def _create_wrapper(self, traced_model):
        """
        Return a function that has an input/output interface the same as the
        original model, but it calls the given traced model under the hood.
        """

        def forward(*args):
            # Flatten caller-provided rich inputs, run the traced graph, and
            # rebuild rich outputs via the schema recorded during tracing.
            flattened_inputs, _ = flatten_to_tuple(args)
            flattened_outputs = traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

        return forward
					
						
						|  |  |