import torch
from torch import nn
from torchvision import transforms, models
from PIL import Image  # needed by the inference demo below
class ActionClassifier(nn.Module):
    """ResNet-50 backbone with a small fully connected head for action recognition."""

    def __init__(self, train_last_nlayer, hidden_size, dropout, ntargets):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT, progress=True)
        modules = list(resnet.children())[:-1]  # drop the final fc layer
        self.resnet = nn.Sequential(*modules)
        # Freeze everything except the last `train_last_nlayer` blocks.
        # Note: slicing with [:-train_last_nlayer] breaks for 0, because
        # [:-0] is the empty slice and nothing would be frozen, so compute
        # the cutoff explicitly.
        freeze_upto = len(self.resnet) - train_last_nlayer
        for param in self.resnet[:freeze_upto].parameters():
            param.requires_grad = False
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(resnet.fc.in_features),
            nn.Dropout(dropout),
            nn.Linear(resnet.fc.in_features, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, ntargets),
            nn.Sigmoid(),  # per-class scores in [0, 1]
        )

    def forward(self, x):
        x = self.resnet(x)  # (N, 2048, 1, 1) pooled features
        x = self.fc(x)      # (N, ntargets) class scores
        return x
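
# A minimal fine-tuning setup sketch (assumptions: `make_optimizer` is a
# hypothetical helper, not part of the original module, and Adam with
# lr=1e-3 is illustrative, not the setting used to train model_weights.pth).
# Because frozen backbone parameters have requires_grad=False, only the fc
# head (and any unfrozen ResNet blocks) should be handed to the optimizer.
def make_optimizer(model, lr=1e-3):
    # Collect only the parameters that still receive gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)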
def get_transform():
    # ResNet50_Weights.DEFAULT.transforms() handles tensor conversion,
    # resizing/cropping to 224x224 and ImageNet normalization; the extra
    # Resize keeps behavior consistent with the original pipeline.
    transform = transforms.Compose([
        transforms.Resize([224, 224]),
        models.ResNet50_Weights.DEFAULT.transforms(),
    ])
    return transform
# Legacy manual preprocessing, kept for reference:
# def get_transform():
#     transform = transforms.Compose([
#         transforms.Resize([224, 224]),
#         transforms.ToTensor(),
#         # std multiplied by 255 to account for ToTensor converting
#         # images from [0, 255] to [0, 1]
#         transforms.Normalize((0.485, 0.456, 0.406),
#                              (0.229 * 255, 0.224 * 255, 0.225 * 255)),
#     ])
#     return transform
def get_model():
    # train_last_nlayer=0 freezes the whole backbone; the head maps
    # 2048 features -> 512 hidden units -> 15 action classes.
    model = ActionClassifier(0, 512, 0.2, 15)
    model.load_state_dict(torch.load('./model_weights.pth', map_location=torch.device('cpu')))
    return model
def get_class(index):
    ind2cat = [
        'calling',
        'clapping',
        'cycling',
        'dancing',
        'drinking',
        'eating',
        'fighting',
        'hugging',
        'laughing',
        'listening_to_music',
        'running',
        'sitting',
        'sleeping',
        'texting',
        'using_laptop',
    ]
    return ind2cat[index]
if __name__ == "__main__":
    # Smoke test: classify a single image from disk.
    transform = get_transform()
    model = get_model()
    model.eval()

    img = Image.open('./inputs/Image_102.jpg').convert('RGB')
    img = transform(img).unsqueeze(dim=0)  # add batch dim: (1, 3, 224, 224)
    print(img.shape)
    with torch.no_grad():
        out = model(img)  # (1, 15) sigmoid scores
    # Softmax before argmax is unnecessary: it does not change the ranking.
    res = torch.argmax(out, dim=1).item()
    print(get_class(res))
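
    # Optional shape sanity check (assumption: any 224x224 RGB batch is a
    # valid input). A random batch of two images should yield one sigmoid
    # score per class for each image.
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        assert model(dummy).shape == (2, 15)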