Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn, optim | |
| from torchvision import transforms, models | |
| #from torch_snippets import * | |
| #from torch.utils.data import DataLoader, Dataset | |
| #from torchsummary import summary | |
| #import seaborn as sns | |
| #import matplotlib.pyplot as plt | |
| #from sklearn.model_selection import train_test_split | |
| from PIL import Image | |
| #import numpy as np | |
| #import cv2 | |
| #from glob import glob | |
| #import pandas as pd | |
| import numpy as np | |
| #device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
class ActionClassifier(nn.Module):
    """ResNet-50 feature extractor followed by a small classification head.

    The backbone is torchvision's ResNet-50 with its final child (the
    original fully connected layer) removed. All backbone parameters are
    frozen, so only ``self.fc`` receives gradients.

    Args:
        ntargets: Number of output logits produced by the head.
        pretrained: Initialise the backbone with ImageNet weights (the
            original behaviour, hence the default). Pass ``False`` when a full
            checkpoint will be loaded right after construction; this avoids
            downloading weights that would be overwritten anyway.
    """

    def __init__(self, ntargets: int, pretrained: bool = True) -> None:
        super().__init__()
        resnet = self._build_resnet50(pretrained)
        in_features = resnet.fc.in_features
        # Keep every child except the last one (the ImageNet classifier).
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Freeze the backbone weights. NOTE: this does not freeze the
        # BatchNorm running statistics inside the backbone; they still update
        # whenever the module is in train mode.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(in_features),
            nn.Dropout(0.2),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256, ntargets),
        )

    @staticmethod
    def _build_resnet50(pretrained: bool) -> nn.Module:
        """Create ResNet-50, preferring the ``weights=`` API when available."""
        try:
            weights_enum = models.ResNet50_Weights
        except AttributeError:
            # torchvision < 0.13 only understands the boolean flag.
            return models.resnet50(pretrained=pretrained, progress=True)
        # ``pretrained=True`` is deprecated since torchvision 0.13 and maps to
        # the IMAGENET1K_V1 weight set.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights, progress=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Assumes ``x`` is a float image batch of shape (N, 3, H, W) — TODO
        confirm against the caller. The backbone output is flattened and
        mapped to ``ntargets`` logits per sample. In train mode the head's
        BatchNorm1d layers require more than one sample per batch.
        """
        features = self.resnet(x)
        return self.fc(features)
def get_transform(
    size=(224, 244),
    mean=(0.485, 0.456, 0.406),
    std=(0.229 * 255, 0.224 * 255, 0.225 * 255),
):
    """Build the preprocessing pipeline applied to a PIL image before inference.

    The pipeline resizes the image, converts it to a float tensor with
    ``ToTensor`` and normalizes each channel as ``(x - mean) / std``.

    Args:
        size: Target size passed unchanged to ``transforms.Resize``.
        mean: Per-channel mean subtracted after ``ToTensor``.
        std: Per-channel divisor applied after the mean is subtracted.

    Returns:
        A ``transforms.Compose`` callable.

    NOTE(review): the defaults reproduce the original pipeline exactly, and
    two of them look unintentional:
      * ``size`` is 224x244 rather than the usual 224x224 (likely a typo);
      * ``ToTensor`` already scales pixel values to [0, 1], so multiplying
        ``std`` by 255 shrinks the normalized inputs roughly 255x compared
        with standard ImageNet normalization.
    They are kept because the saved classifier weights presumably expect
    this exact preprocessing. Confirm against the training pipeline before
    changing them.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
def get_model(weights_path='./classifier_weights.pth', num_classes=15,
              map_location='cpu'):
    """Build an ``ActionClassifier`` and load trained weights into it.

    Args:
        weights_path: Path of the saved state dict. Defaults to the original
            hard-coded location.
        num_classes: Output size of the classifier head. It must match the
            checkpoint; the original hard-coded value was 15.
        map_location: Forwarded to ``torch.load`` so that a checkpoint saved
            on another device can be loaded on a CPU-only machine.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = ActionClassifier(num_classes)
    # The file is assumed to be a trusted, locally shipped checkpoint;
    # torch.load unpickles its contents.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # The head contains BatchNorm1d and Dropout. In training mode BatchNorm1d
    # rejects single-sample batches and Dropout makes outputs random, so
    # return a model ready for inference. Call ``model.train()`` to fine-tune.
    model.eval()
    return model
# Class labels in the order of the classifier's output logits.
ACTION_CLASSES = (
    'calling',
    'clapping',
    'cycling',
    'dancing',
    'drinking',
    'eating',
    'fighting',
    'hugging',
    'laughing',
    'listening_to_music',
    'running',
    'sitting',
    'sleeping',
    'texting',
    'using_laptop',
)


def get_class(index):
    """Return the action label for a predicted class index.

    Args:
        index: Class index. Any integer-like value accepted by
            ``operator.index`` works, e.g. an ``int`` or a single-element
            integer tensor such as the result of ``torch.argmax``.

    Returns:
        The label string at ``index`` in ``ACTION_CLASSES``.

    Raises:
        IndexError: If ``index`` is outside ``[0, len(ACTION_CLASSES))``.
            Negative indices are rejected rather than wrapping around to the
            end of the tuple, which would silently return a wrong label.
        TypeError: If ``index`` is not integer-like.
    """
    import operator

    position = operator.index(index)
    if not 0 <= position < len(ACTION_CLASSES):
        raise IndexError(
            f'class index {position} out of range for '
            f'{len(ACTION_CLASSES)} classes'
        )
    return ACTION_CLASSES[position]
| # img = Image.open('./inputs/Image_102.jpg').convert('RGB') | |
| # #print(transform(img)) | |
| # img = transform(img) | |
| # img = img.unsqueeze(dim=0) | |
| # print(img.shape) | |
| # model.eval() | |
| # with torch.no_grad(): | |
| # out = model(img) | |
| # out = nn.Softmax()(out).squeeze() | |
| # print(out.shape) | |
| # res = torch.argmax(out) | |
| # print(ind2cat[res]) | |