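# MNIST digit classifier served through a Gradio Blocks UI.
# A small three-convolution CNN is loaded from pretrained weights and exposed
# behind two input tabs: a drawing canvas and an image-file upload.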
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
# Define model
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        #x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(-1, 3*3*64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        logits = self.fc2(x)
        return logits
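# Shape trace for a 1x28x28 input (N = batch size):
#   conv1 (k=5)             -> N x 32 x 24 x 24
#   conv2 (k=5) + maxpool 2 -> N x 32 x 10 x 10
#   conv3 (k=5) + maxpool 2 -> N x 64 x  3 x  3
#   flatten                 -> N x 576 (= 3*3*64, matching fc1's input size)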
model = ConvNet()
model.load_state_dict(
    torch.load("weights/mnist_convnet_model.pth",
               map_location=torch.device('cpu'))
)
model.eval()
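# eval() sets self.training to False, so the F.dropout(..., training=self.training)
# calls in forward() become no-ops at inference time.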
import gradio as gr
from torchvision import transforms
import os
import glob

examples_dir = './examples'
example_files = glob.glob(os.path.join(examples_dir, '*.png'))
def predict(image):
    # Convert the PIL image to a float tensor in [0, 1] and add the batch
    # dimension expected by nn.Conv2d: (1, 28, 28) -> (1, 1, 28, 28).
    tsr_image = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        pred = model(tsr_image)
        prob = torch.nn.functional.softmax(pred[0], dim=0)
    confidences = {i: float(prob[i]) for i in range(10)}
    return confidences
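# Quick local sanity check (hypothetical path; assumes a 28x28 grayscale sample
# exists under ./examples):
#   from PIL import Image
#   print(predict(Image.open('examples/sample02.png').convert('L')))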
with gr.Blocks(css=".gradio-container {background:honeydew;}", title="MNIST Classification") as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">MNIST Classification</div>""")
    with gr.Row():
        with gr.Tab("Canvas"):
            input_image1 = gr.Image(source="canvas", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn1 = gr.Button("Infer")
        with gr.Tab("Image file"):
            input_image2 = gr.Image(type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn2 = gr.Button("Infer")
            gr.Examples(example_files, inputs=input_image2)
            #gr.Examples(['examples/sample02.png', 'examples/sample04.png'], inputs=input_image2)
        output_label = gr.Label(label="Probabilities", num_top_classes=3)
    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)

# demo.queue(concurrency_count=3)
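# Optional launch settings (not used here): when running locally,
# demo.launch(share=True) would also create a temporary public URL.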
demo.launch()

### EOF ###