Spaces: Running on Zero
| import torch.nn as nn | |
| from einops import rearrange | |
class PixelShuffleND(nn.Module):
    """Depth-to-space ("pixel shuffle") upsampling over 1, 2, or 3 axes.

    The channel axis is split into ``c`` output channels times one factor per
    upsampled axis, and each factor is folded into its axis with
    ``einops.rearrange``:

    * ``dims=3``: ``(b, c*p1*p2*p3, d, h, w) -> (b, c, d*p1, h*p2, w*p3)``
    * ``dims=2``: ``(b, c*p1*p2, h, w) -> (b, c, h*p1, w*p2)``
    * ``dims=1``: ``(b, c*p1, f, h, w) -> (b, c, f*p1, h, w)``

    Note that ``dims=1`` still expects a 5-D input and upsamples only the
    first of its three trailing axes, whereas ``dims=2`` expects a 4-D input.

    Args:
        dims: Number of axes to upsample; must be 1, 2, or 3.
        upscale_factors: Per-axis upscale factors; only the first ``dims``
            entries are used. A single ``int`` is applied to every axis.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3, or if fewer than ``dims``
            upscale factors are given.
    """

    # einops pattern per ``dims``; the i-th upscale factor binds to ``p{i}``.
    _PATTERNS = {
        3: "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
        2: "b (c p1 p2) h w -> b c (h p1) (w p2)",
        1: "b (c p1) f h w -> b c (f p1) h w",
    }

    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if dims not in (1, 2, 3):
            raise ValueError("dims must be 1, 2, or 3")
        if isinstance(upscale_factors, int):
            upscale_factors = (upscale_factors,) * dims
        # Fail at construction rather than with an IndexError in forward().
        if len(upscale_factors) < dims:
            raise ValueError(
                f"expected at least {dims} upscale factors, "
                f"got {len(upscale_factors)}"
            )
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        """Fold channel groups of ``x`` into its trailing axes.

        Args:
            x: Input tensor laid out as described in the class docstring.

        Returns:
            The rearranged tensor produced by ``einops.rearrange``.

        Raises:
            ValueError: If ``self.dims`` is not 1, 2, or 3 (for example, if it
                was reassigned after construction).
        """
        pattern = self._PATTERNS.get(self.dims)
        if pattern is None:
            # Without this guard the call would fall through and return None.
            raise ValueError("dims must be 1, 2, or 3")
        used = self.upscale_factors[: self.dims]
        factors = {f"p{i}": f for i, f in enumerate(used, start=1)}
        return rearrange(x, pattern, **factors)