Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
# Compatibility shim: NumPy 1.24 removed the deprecated ``np.int`` alias.
# NOTE(review): presumably the hub repo's code still references ``np.int``;
# confirm. Either way this assignment must stay before torch.hub.load below.
np.int = int

# Pretrained NTS-Net for CUB-200, fetched via PyTorch Hub. torch.hub.load
# forwards the extra keyword arguments to the repo's ``ntsnet`` entry point.
model = torch.hub.load(
    "nicolalandro/ntsnet-cub200",
    "ntsnet",
    pretrained=True,
    topN=6,
    device="cpu",
    num_classes=200,
)

# Switch to evaluation mode (affects layers such as dropout / batch-norm);
# the app only ever runs inference.
model.eval()
def classify_bird(img):
    """Return per-class probabilities for a bird photo.

    The image is resized to 600x600, center-cropped to 448x448, converted to a
    tensor and normalized, then run through the module-level ``model``.

    Args:
        img: A ``PIL.Image.Image`` (the ``gr.Image`` input is configured with
            ``type="pil"``), or ``None`` when no image was provided.

    Returns:
        dict mapping each name in ``model.bird_classes`` to its softmax
        probability as a Python ``float``. Returns an empty dict when ``img``
        is ``None``; Gradio's Label treats an empty mapping as "no result".
    """
    if img is None:
        # Nothing uploaded: return an empty result instead of crashing
        # inside the transform pipeline.
        return {}

    # Normalize below uses 3-channel mean/std, so RGBA, grayscale or palette
    # images must be converted first or the tensor arithmetic fails.
    if img.mode != "RGB":
        img = img.convert("RGB")

    transform_test = transforms.Compose([
        transforms.Resize(
            (600, 600), interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        # Per-channel mean/std (the widely used ImageNet statistics).
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Add a batch dimension: (3, 448, 448) -> (1, 3, 448, 448).
    batch = transform_test(img).unsqueeze(0)

    with torch.no_grad():
        # The model returns a 7-tuple; only concat_logits (4th item) is used.
        _, _, _, concat_logits, _, _, _ = model(batch)

    # Softmax over the class dimension, then take the single batch entry.
    probs = torch.softmax(concat_logits, 1)[0]
    return dict(zip(model.bird_classes, probs.tolist()))
# Input: the uploaded photo is handed to classify_bird as a PIL image.
image_component = gr.Image(type="pil", label="Bird Image")
# Output: show the three most probable classes from the returned dict.
label_component = gr.Label(label="Classification result", num_top_classes=3)

# NOTE(review): the emoji in these strings were garbled (mojibake) in the
# copy this was recovered from. The title's 🐣 is a confident decode; the
# heading and link emoji are best-effort replacements. Confirm against the
# original app.
description = """
## About 🚀
Tutorial for deploying a gradio app on huggingface. This was done during a [livestream](https://youtube.com/live/bN9WTxzLBRE) on YouTube.
## Links 🔗
🔗 YouTube Livestream: https://youtube.com/live/bN9WTxzLBRE\n
🔗 Torchvision Model: https://pytorch.org/hub/nicolalandro_ntsnet-cub200_ntsnet/\n
🔗 Paper: http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf\n
"""
title = "Bird Classifier 🐣"

demo = gr.Interface(
    fn=classify_bird,
    inputs=image_component,
    outputs=label_component,
    description=description,
    title=title,
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()