more layers
Browse files
app.py
CHANGED
|
@@ -16,9 +16,21 @@ def predict(img):
|
|
| 16 |
topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
|
| 17 |
return [str(k) for k in topk_indices[0].tolist()]
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
inputs=sp,
|
| 24 |
-
outputs=['label','label']).launch()
|
|
|
|
| 16 |
topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
|
| 17 |
return [str(k) for k in topk_indices[0].tolist()]
|
| 18 |
|
| 19 |
+
with gr.Blocks() as iface:
|
| 20 |
+
gr.Markdown("# MNIST + Gradio End to End")
|
| 21 |
+
gr.HTML("Shows end to end MNIST training with Gradio interface")
|
| 22 |
+
with gr.Row():
|
| 23 |
+
with gr.Column():
|
| 24 |
+
sp = gr.Sketchpad(shape=(28, 28))
|
| 25 |
+
with gr.Row():
|
| 26 |
+
with gr.Column():
|
| 27 |
+
pred_button = gr.Button("Predict")
|
| 28 |
+
with gr.Column():
|
| 29 |
+
clear = gr.Button("Clear")
|
| 30 |
+
with gr.Column():
|
| 31 |
+
label1 = gr.Label(label='1st Pred')
|
| 32 |
+
label2 = gr.Label(label='2nd Pred')
|
| 33 |
|
| 34 |
+
pred_button.click(predict, inputs=sp, outputs=[label1,label2])
|
| 35 |
+
clear.click(lambda: None, None, sp, queue=False)
|
| 36 |
+
iface.launch()
|
|
|
|
|
|
mnist.pth
CHANGED
|
Binary files a/mnist.pth and b/mnist.pth differ
|
|
|
model.py
CHANGED
|
@@ -5,11 +5,13 @@ class Net(nn.Module):
|
|
| 5 |
def __init__(self):
|
| 6 |
super(Net, self).__init__()
|
| 7 |
self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
|
| 8 |
-
self.fc2 = nn.Linear(128,
|
| 9 |
-
self.fc3 = nn.Linear(
|
|
|
|
| 10 |
|
| 11 |
def forward(self, x):
|
| 12 |
x = x.view(x.shape[0], -1) # Flatten the input
|
| 13 |
x = torch.relu(self.fc1(x))
|
| 14 |
x = torch.relu(self.fc2(x))
|
| 15 |
-
|
|
|
|
|
|
| 5 |
def __init__(self):
|
| 6 |
super(Net, self).__init__()
|
| 7 |
self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
|
| 8 |
+
self.fc2 = nn.Linear(128, 128)
|
| 9 |
+
self.fc3 = nn.Linear(128, 64)
|
| 10 |
+
self.fc4 = nn.Linear(64, 10) # There are 10 classes (0 through 9)
|
| 11 |
|
| 12 |
def forward(self, x):
|
| 13 |
x = x.view(x.shape[0], -1) # Flatten the input
|
| 14 |
x = torch.relu(self.fc1(x))
|
| 15 |
x = torch.relu(self.fc2(x))
|
| 16 |
+
x = torch.relu(self.fc3(x))
|
| 17 |
+
return self.fc4(x)
|
requirements.txt
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
gradio==3.29.0
|
| 2 |
-
numpy==1.23.5
|
| 3 |
Pillow==9.1.0
|
| 4 |
torch==2.0.1
|
| 5 |
torchvision==0.15.2
|
|
|
|
|
|
|
|
|
|
| 1 |
Pillow==9.1.0
|
| 2 |
torch==2.0.1
|
| 3 |
torchvision==0.15.2
|