import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
import flax
from flax.linen.module import merge_param
import flax.linen as nn
from typing import Callable, Iterable, Optional, Tuple, Union, Any
import functools
import pickle

from . import utils

PRNGKey = Any
Array = Any
Shape = Tuple[int]
Dtype = Any


class InceptionV3(nn.Module):
    """
    InceptionV3 network.
    Reference: https://arxiv.org/abs/1512.00567
    Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py

    Attributes:
        include_head (bool): If True, include the classifier head.
        num_classes (int): Number of output classes. Ignored when pretrained weights
            are used, in which case the head has 1000 classes.
        pretrained (bool): If True, load pretrained weights.
        transform_input (bool): If True, preprocesses the input according to the method
            with which it was trained on ImageNet.
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
        ckpt_path (str): URL of the pretrained checkpoint.
        dtype (str): Data type of the computation.
    """
    include_head: bool = False
    num_classes: int = 1000
    pretrained: bool = False
    transform_input: bool = False
    aux_logits: bool = False
    ckpt_path: str = 'https://www.dropbox.com/s/0zo4pd6cfwgzem7/inception_v3_weights_fid.pickle?dl=1'
    dtype: str = 'float32'

    def setup(self):
        if self.pretrained:
            ckpt_file = utils.download(self.ckpt_path)
            with open(ckpt_file, 'rb') as f:
                self.params_dict = pickle.load(f)
            self.num_classes_ = 1000
        else:
            self.params_dict = None
            self.num_classes_ = self.num_classes

    @nn.compact
    def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
        """
        Args:
            x (tensor): Input image, shape [B, H, W, C].
            train (bool): If True, run in training mode (batch statistics are computed
                and dropout is applied).
            rng (jax.random.PRNGKey): PRNG key for dropout.
        """
        x = self._transform_input(x)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        strides=(2, 2),
                        params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=64,
                        kernel_size=(3, 3),
                        padding=((1, 1), (1, 1)),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = BasicConv2d(out_channels=80,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=192,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = InceptionA(pool_features=32,
                       params_dict=utils.get(self.params_dict, 'Mixed_5b'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5c'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5d'),
                       dtype=self.dtype)(x, train)
        x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=128,
                       params_dict=utils.get(self.params_dict, 'Mixed_6b'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6c'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6d'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=192,
                       params_dict=utils.get(self.params_dict, 'Mixed_6e'),
                       dtype=self.dtype)(x, train)
        aux = None
        if self.aux_logits and train:
            aux = InceptionAux(num_classes=self.num_classes_,
                               params_dict=utils.get(self.params_dict, 'AuxLogits'),
                               dtype=self.dtype)(x, train)
        x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
                       dtype=self.dtype)(x, train)
        x = InceptionE(pooling=avg_pool,
                       params_dict=utils.get(self.params_dict, 'Mixed_7b'),
                       dtype=self.dtype)(x, train)
        # Following the implementation by @mseitzer, we use max pooling instead
        # of average pooling here.
        # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
        x = InceptionE(pooling=nn.max_pool,
                       params_dict=utils.get(self.params_dict, 'Mixed_7c'),
                       dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2), keepdims=True)
        if not self.include_head:
            return x
        x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
        x = jnp.reshape(x, (x.shape[0], -1))
        x = Dense(features=self.num_classes_,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        if self.aux_logits:
            return x, aux
        return x

    def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
        return x
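
# Example usage (a minimal sketch, not part of the module itself; the 299x299
# input size and the init/apply pattern are assumptions for illustration):
#
#   model = InceptionV3(pretrained=True)   # FID setting: no head, no aux logits
#   dummy = jnp.ones((1, 299, 299, 3))
#   variables = model.init(jax.random.PRNGKey(0), dummy, train=False)
#   features = model.apply(variables, dummy, train=False)
#   # features has shape [1, 1, 1, 2048]: the pooled Mixed_7c activations used for FID.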


class Dense(nn.Module):
    features: int
    kernel_init: functools.partial = nn.initializers.lecun_normal()
    bias_init: functools.partial = nn.initializers.zeros
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x):
        # When pretrained weights are provided, the initializers ignore the PRNG key
        # and the requested shape and simply return the stored arrays.
        x = nn.Dense(features=self.features,
                     kernel_init=self.kernel_init if self.params_dict is None else lambda *_: jnp.array(self.params_dict['kernel']),
                     bias_init=self.bias_init if self.params_dict is None else lambda *_: jnp.array(self.params_dict['bias']))(x)
        return x
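
# A small sketch of the pretrained-weight override (the weights dict below is
# hypothetical, for illustration only):
#
#   weights = {'kernel': jnp.zeros((16, 8)), 'bias': jnp.zeros((8,))}
#   layer = Dense(features=8, params_dict=weights)
#   params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))
#   # params['params']['Dense_0']['kernel'] is the stored kernel, not a random init.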


class BasicConv2d(nn.Module):
    out_channels: int
    kernel_size: Union[int, Iterable[int]] = (3, 3)
    strides: Optional[Iterable[int]] = (1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]] = 'valid'
    use_bias: bool = False
    kernel_init: functools.partial = nn.initializers.lecun_normal()
    bias_init: functools.partial = nn.initializers.zeros
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = nn.Conv(features=self.out_channels,
                    kernel_size=self.kernel_size,
                    strides=self.strides,
                    padding=self.padding,
                    use_bias=self.use_bias,
                    kernel_init=self.kernel_init if self.params_dict is None else lambda *_: jnp.array(self.params_dict['conv']['kernel']),
                    bias_init=self.bias_init if self.params_dict is None else lambda *_: jnp.array(self.params_dict['conv']['bias']),
                    dtype=self.dtype)(x)
        if self.params_dict is None:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        else:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          bias_init=lambda *_: jnp.array(self.params_dict['bn']['bias']),
                          scale_init=lambda *_: jnp.array(self.params_dict['bn']['scale']),
                          mean_init=lambda *_: jnp.array(self.params_dict['bn']['mean']),
                          var_init=lambda *_: jnp.array(self.params_dict['bn']['var']),
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        x = jax.nn.relu(x)
        return x


class InceptionA(nn.Module):
    pool_features: int
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=64,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=48,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch5x5_1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=64,
                                kernel_size=(5, 5),
                                padding=((2, 2), (2, 2)),
                                params_dict=utils.get(self.params_dict, 'branch5x5_2'),
                                dtype=self.dtype)(branch5x5, train)
        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=self.pool_features,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionB(nn.Module):
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3'),
                                dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   strides=(2, 2),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionC(nn.Module):
    channels_7x7: int
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch7x7_1'),
                                dtype=self.dtype)(x, train)
        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 7),
                                padding=((0, 0), (3, 3)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_2'),
                                dtype=self.dtype)(branch7x7, train)
        branch7x7 = BasicConv2d(out_channels=192,
                                kernel_size=(7, 1),
                                padding=((3, 3), (0, 0)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_3'),
                                dtype=self.dtype)(branch7x7, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        # The final 1x7 convolution outputs 192 channels, matching the torchvision
        # reference, so that the block's output has 768 channels in total.
        branch7x7dbl = BasicConv2d(out_channels=192,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
        return output


class InceptionD(nn.Module):
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3 = BasicConv2d(out_channels=320,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3_2'),
                                dtype=self.dtype)(branch3x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
                                  dtype=self.dtype)(x, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 7),
                                  padding=((0, 0), (3, 3)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(7, 1),
                                  padding=((3, 3), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(3, 3),
                                  strides=(2, 2),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
        return output


class InceptionE(nn.Module):
    pooling: Callable
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=320,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3_a = BasicConv2d(out_channels=384,
                                  kernel_size=(1, 3),
                                  padding=((0, 0), (1, 1)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3_b = BasicConv2d(out_channels=384,
                                  kernel_size=(3, 1),
                                  padding=((1, 1), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)
        branch3x3dbl = BasicConv2d(out_channels=448,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=384,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_a = BasicConv2d(out_channels=384,
                                     kernel_size=(1, 3),
                                     padding=((0, 0), (1, 1)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_b = BasicConv2d(out_channels=384,
                                     kernel_size=(3, 1),
                                     padding=((1, 1), (0, 0)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)
        branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionAux(nn.Module):
    num_classes: int
    kernel_init: functools.partial = nn.initializers.lecun_normal()
    bias_init: functools.partial = nn.initializers.zeros
    params_dict: dict = None
    dtype: str = 'float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
        x = BasicConv2d(out_channels=128,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'conv0'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=768,
                        kernel_size=(5, 5),
                        params_dict=utils.get(self.params_dict, 'conv1'),
                        dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2))
        x = jnp.reshape(x, (x.shape[0], -1))
        x = Dense(features=self.num_classes,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        return x


def _absolute_dims(rank, dims):
    return tuple([rank + dim if dim < 0 else dim for dim in dims])


class BatchNorm(nn.Module):
    """BatchNorm Module.
    Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py

    Attributes:
        use_running_average: if True, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the input.
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias: if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
        mean_init: initializer for the running mean, by default, zero.
        var_init: initializer for the running variance, by default, one.
        axis_name: the axis name used to combine batch statistics from multiple
            devices. See `jax.pmap` for a description of axis names (default: None).
        axis_index_groups: groups of axis indices within that named axis
            representing subsets of devices to reduce over (default: None). For
            example, `[[0, 1], [2, 3]]` would independently batch-normalize over
            the examples on the first two and last two devices. See `jax.lax.psum`
            for more details.
    """
    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
    var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

        NOTE:
            During initialization (when parameters are mutable) the running average
            of the batch statistics will not be updated. Therefore, the inputs
            fed during initialization don't need to match that of the actual input
            distribution and the reduction axis (set with `axis_name`) does not have
            to exist.

        Args:
            x: the input to be normalized.
            use_running_average: if true, the statistics stored in batch_stats
                will be used instead of computing the batch statistics on the input.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        use_running_average = merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
        # see NOTE above on initialization behavior
        initializing = self.is_mutable_collection('params')
        ra_mean = self.variable('batch_stats', 'mean',
                                self.mean_init,
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               self.var_init,
                               reduced_feature_shape)
        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(
                        concatenated_mean,
                        axis_name=self.axis_name,
                        axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)
            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale',
                               self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias',
                              self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
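
# Example (sketch): when batch statistics are computed (use_running_average=False),
# the running averages live in the 'batch_stats' collection and must be marked
# mutable during apply. The shapes below are illustrative:
#
#   bn = BatchNorm(momentum=0.1, epsilon=0.001)
#   xb = jnp.ones((4, 8, 8, 32))
#   variables = bn.init(jax.random.PRNGKey(0), xb, use_running_average=False)
#   y, updates = bn.apply(variables, xb, use_running_average=False, mutable=['batch_stats'])
#   # updates['batch_stats'] holds the new running mean and variance.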


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """
    Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py

    Helper function to define pooling functions.
    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.

    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        init: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window strides.
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension.
    Returns:
        The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    assert len(window_shape) == len(strides), (
        f"len({window_shape}) == len({strides})")
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)
    is_single_input = False
    if inputs.ndim == len(dims) - 1:
        # add singleton batch dimension because lax.reduce_window always
        # needs a batch dimension.
        inputs = inputs[None]
        is_single_input = True
    assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
    if not isinstance(padding, str):
        padding = tuple(map(tuple, padding))
        assert len(padding) == len(window_shape), (
            f"padding {padding} must specify pads for same number of dims as "
            f"window_shape {window_shape}")
        assert all([len(x) == 2 for x in padding]), (
            f"each entry in padding {padding} must be length 2")
        padding = ((0, 0),) + padding + ((0, 0),)
    y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    if is_single_input:
        y = jnp.squeeze(y, axis=0)
    return y


def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.

    In contrast to flax.linen.avg_pool, this pooling operation does not
    include the padded zeros in the average computation.

    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension (default: `'VALID'`).
    Returns:
        The average for each window slice.
    """
    assert inputs.ndim == 4
    assert len(window_shape) == 2
    y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    # Divide each window sum by the number of valid (non-padded) inputs inside
    # that window, rather than by the full window size.
    ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
    kernel = jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1))
    counts = jax.lax.conv_general_dilated(ones,
                                          kernel,
                                          window_strides=strides or (1, 1),
                                          padding=padding if isinstance(padding, str) else tuple(map(tuple, padding)),
                                          dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
                                          feature_group_count=1)
    y = y / counts
    return y
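
# Worked example (sketch): with a 3x3 window, stride 1, and (1, 1) padding, a corner
# output averages only the 4 valid inputs instead of dividing by 9:
#
#   x = jnp.ones((1, 4, 4, 1))
#   y = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
#   # y is all ones; flax.linen.avg_pool with 'SAME' padding would give 4/9 at the corners.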