Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from io import StringIO | |
| import requests | |
| import torch | |
| from torchvision.models.inception import inception_v3 | |
| import matplotlib.pyplot as plt | |
| from skimage.transform import resize | |
def load_stuff():
    """Load the pretrained Inception-v3 classifier and the ImageNet class labels.

    Returns:
        A ``(model, labels)`` tuple: the model switched to inference mode, and
        a dict mapping class index -> label, built from the JSON list
        downloaded from LABELS_URL.

    Raises:
        requests.RequestException: if the labels cannot be fetched (e.g. a
            timeout or an HTTP error status).
    """
    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
    # favour of `weights=Inception_V3_Weights.IMAGENET1K_V1`; kept here so the
    # code still runs on older torchvision releases.
    model = inception_v3(
        pretrained=True,  # load existing weights
        # torchvision re-maps inputs normalized with the ImageNet mean/std to
        # the scaling the original Inception weights were trained with.
        transform_input=True,
    )
    model.aux_logits = False  # don't predict intermediate logits (yellow layers at the bottom)
    model.eval()  # equivalent to model.train(False)
    LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
    # A timeout keeps a stalled download from hanging the app indefinitely;
    # raise_for_status reports an HTTP error clearly instead of failing later
    # while decoding an error page as JSON.
    response = requests.get(LABELS_URL, timeout=10)
    response.raise_for_status()
    labels = dict(enumerate(response.json()))
    return model, labels
# Streamlit re-executes this whole script on every user interaction. When the
# installed Streamlit provides st.cache_resource, route the loader through it
# so the model and labels are built once and reused across reruns instead of
# being reloaded (and the labels re-downloaded) every time; older Streamlit
# versions fall back to a plain call.
_load_stuff_cached = (
    st.cache_resource(load_stuff) if hasattr(st, 'cache_resource') else load_stuff
)
model, labels = _load_stuff_cached()
def transform_input(img):
    """Turn an RGB image array into a normalized model input batch.

    Args:
        img: array with 299 * 299 * 3 elements laid out as height x width x
            RGB; values are expected in [0, 1] (as produced by skimage's
            ``resize``) — TODO confirm for other callers.

    Returns:
        A float32 tensor of shape (1, 3, 299, 299), normalized per channel
        with the ImageNet mean/std. This is the input format torchvision's
        ``inception_v3(transform_input=True)`` assumes before re-mapping it.
    """
    # NHWC -> NCHW, the layout PyTorch convolutions expect.
    batch = torch.as_tensor(
        img.reshape([1, 299, 299, 3]).transpose([0, 3, 1, 2]), dtype=torch.float32
    )
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
    return (batch - mean) / std
def predict(img, top_k=10):
    """Classify *img* with the module-level model and describe the top classes.

    Args:
        img: image array accepted by ``transform_input`` (reshaped there to
            299 x 299 x 3).
        top_k: how many of the most probable classes to list (default 10).

    Returns:
        A Markdown string: a header followed by one ``"<prob> :\\t<label>"``
        line per class, most probable first, each followed by a blank line.
    """
    batch = transform_input(img)
    # Pure inference: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch), dim=-1)
    probs = probs.numpy().ravel()
    # argsort is ascending, so reverse it and keep the first top_k indices.
    # (The original slice [-1:-10:-1] returned only 9 classes.)
    top_ix = probs.argsort()[::-1][:top_k]
    header = 'top-%d classes are: \n\n [prob : class label]\n\n' % top_k
    rows = ('%.4f :\t%s' % (probs[ix], labels[int(ix)].split(',')[0]) + '\n\n'
            for ix in top_ix)
    return header + ''.join(rows)
import os
import tempfile

st.markdown("### Hello dude!")
uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()
    # Decode through a private temporary directory: all sessions of a
    # Streamlit app share one server process and working directory, so a
    # fixed path like 'tmp' lets concurrent uploads overwrite each other.
    # The directory is removed when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'upload')
        with open(tmp_path, 'wb') as f:
            f.write(bytes_data)
        try:
            raw = plt.imread(tmp_path)
        except (OSError, ValueError):  # e.g. the upload is not an image
            raw = None
    if raw is None:
        st.error('Could not read the uploaded file as an image.')
    else:
        img = resize(raw, (299, 299))
        # The classifier needs exactly 3 (RGB) channels: expand grayscale
        # (2-D, single-channel, or gray+alpha) and drop any alpha channel.
        # Slicing a 2-D array with [..., :3] would cut columns instead.
        if img.ndim == 2:
            img = img[..., None]
        if img.shape[-1] < 3:
            img = img[..., :1].repeat(3, axis=-1)
        img = img[..., :3]
        top_classes = predict(img)
        st.markdown(top_classes)
        st.image(bytes_data)