import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm


class ResNet18(nn.Module):
    def __init__(self, pretrained=False) -> None:
        super().__init__()
        # `pretrained` is torchvision's legacy flag; newer releases prefer `weights=`.
        self.net = tvm.resnet18(pretrained=pretrained)

    def forward(self, x):
        net = self.net
        x1 = x
        x = net.conv1(x1)
        x = net.bn1(x)
        x2 = net.relu(x)
        x = net.maxpool(x2)
        x4 = net.layer1(x)
        x8 = net.layer2(x4)
        x16 = net.layer3(x8)
        x32 = net.layer4(x16)
        # Feature pyramid keyed by downsampling factor relative to the input.
        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1}

    def train(self, mode=True):
        super().train(mode)
        # Keep BatchNorm statistics frozen even in training mode.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
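
# Usage sketch (illustrative, not part of the original module): every backbone
# in this file returns a dict of feature maps keyed by stride, so downstream
# code can pick the resolutions it needs, e.g.
#
#   model = ResNet18(pretrained=False).eval()
#   feats = model(torch.randn(1, 3, 224, 224))
#   # feats[4].shape == (1, 64, 56, 56); feats[32].shape == (1, 512, 7, 7)
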

class ResNet50(nn.Module):
    def __init__(
        self,
        pretrained=False,
        high_res=False,
        weights=None,
        dilation=None,
        freeze_bn=True,
        anti_aliased=False,
    ) -> None:
        super().__init__()
        if dilation is None:
            dilation = [False, False, False]
        if anti_aliased:
            # The anti-aliased backbone variant is not implemented here; fail
            # loudly instead of silently leaving self.net undefined.
            raise NotImplementedError("anti_aliased=True is not supported")
        if weights is not None:
            self.net = tvm.resnet50(
                weights=weights, replace_stride_with_dilation=dilation
            )
        else:
            self.net = tvm.resnet50(
                pretrained=pretrained, replace_stride_with_dilation=dilation
            )
        self.high_res = high_res
        self.freeze_bn = freeze_bn

    def forward(self, x):
        net = self.net
        feats = {1: x}
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = x
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = x
        x = net.layer2(x)
        feats[8] = x
        x = net.layer3(x)
        feats[16] = x
        x = net.layer4(x)
        feats[32] = x
        return feats

    def train(self, mode=True):
        super().train(mode)
        if self.freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()


class ResNet101(nn.Module):
    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.resnet101(weights=weights)
        else:
            self.net = tvm.resnet101(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def forward(self, x):
        net = self.net
        feats = {1: x}
        sf = self.scale_factor
        if self.high_res:
            # Run the backbone on an upsampled input, then resize each feature
            # map back so the stride keys stay consistent with the original input.
            x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer2(x)
        feats[8] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer3(x)
        feats[16] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer4(x)
        feats[32] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        return feats

    def train(self, mode=True):
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


class WideResNet50(nn.Module):
    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.wide_resnet50_2(weights=weights)
        else:
            self.net = tvm.wide_resnet50_2(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def forward(self, x):
        net = self.net
        feats = {1: x}
        sf = self.scale_factor
        if self.high_res:
            x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer2(x)
        feats[8] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer3(x)
        feats[16] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        x = net.layer4(x)
        feats[32] = (
            x
            if not self.high_res
            else F.interpolate(
                x, scale_factor=1 / sf, align_corners=False, mode="bilinear"
            )
        )
        return feats

    def train(self, mode=True):
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
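

# Minimal smoke test (an illustrative sketch, not part of the original module):
# build a randomly initialized backbone and check that each returned feature
# map's spatial size matches its stride key.
if __name__ == "__main__":
    model = ResNet50(pretrained=False).eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for stride, feat in sorted(feats.items()):
        assert feat.shape[-1] == 224 // stride, (stride, feat.shape)
        print(stride, tuple(feat.shape))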