Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| from torchvision import datasets | |
| from torchvision.transforms import ToTensor | |
class NeuralNetwork(nn.Module):
    """Fully connected classifier: 28*28 = 784 input features -> 10 logits.

    Layout: Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU
    -> Linear(512, 10). ``forward`` returns raw logits; no softmax is applied.

    The attribute names ``flatten`` and ``linear_relu_stack`` and the order of
    layers inside the Sequential define the ``state_dict`` keys
    (``linear_relu_stack.0.weight`` ...), so keep them unchanged or saved
    checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        n_inputs = 28 * 28
        width = 512
        n_classes = 10
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(n_inputs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_classes),
        )

    def forward(self, x):
        """Flatten all dims after the first (nn.Flatten default) and return logits."""
        return self.linear_relu_stack(self.flatten(x))
# Instantiate the classifier and restore the trained weights.
model = NeuralNetwork()
# map_location="cpu": a checkpoint saved from a GPU session stores CUDA
# tensors; without remapping, torch.load raises on a CPU-only host.
model.load_state_dict(torch.load("model_mnist_mlp.pth", map_location="cpu"))
# Inference mode. This MLP has no dropout/batch-norm layers, so this is the
# conventional pre-serving step rather than a behavioral switch here.
model.eval()
| import gradio as gr | |
| from torchvision import transforms | |
def predict(image):
    """Classify a handwritten digit image with the loaded MLP.

    Args:
        image: PIL image delivered by a ``gr.Image(type="pil")`` input, or
            ``None`` when the button is pressed with nothing drawn/uploaded.

    Returns:
        dict mapping each class index 0-9 to its softmax probability (the
        format ``gr.Label`` displays), or ``None`` when no image was given.
    """
    # Gradio passes None for an empty input; ToTensor would raise TypeError.
    if image is None:
        return None
    # The UI components already request image_mode="L" and shape=(28, 28);
    # repeating that here guarantees 784 features for nn.Linear(28*28, ...)
    # even if the input component did not resize. Both calls return an
    # unchanged copy for an image that is already 28x28 mode "L".
    image = image.convert("L").resize((28, 28))
    # ToTensor yields (1, 28, 28); add an explicit batch axis -> (1, 1, 28, 28)
    # so Flatten(start_dim=1) produces (1, 784) by design.
    tsr_image = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        pred = model(tsr_image)
    # pred is (1, 10); softmax over the 10 logits of the single sample.
    prob = torch.nn.functional.softmax(pred[0], dim=0)
    return dict(enumerate(prob.tolist()))
# Gradio UI: two input tabs (drawing canvas, uploaded image file), each with
# its own predict button, both writing to one shared probability label.
# NOTE(review): gr.Image(source=..., shape=..., invert_colors=...) and
# demo.queue(concurrency_count=...) are Gradio 3.x arguments removed in
# Gradio 4 -- pin gradio<4 in requirements or port this UI.
# NOTE(review): nesting was reconstructed from a listing that lost its
# indentation; confirm the intended layout.
with gr.Blocks(
    css=".gradio-container {background:lightyellow;color:red;}",
    title="テスト",  # "Test"
) as demo:
    # Heading ("MNIST classifier"). The ">" closing the opening <div> tag is
    # required; without it the text is parsed into the tag and never shown.
    gr.HTML(
        '<div style="font-size:12pt; text-align:center; color:yellow;">MNIST 分類器</div>'
    )
    with gr.Row():
        with gr.Tab("キャンバス"):  # "Canvas"
            # NOTE(review): invert_colors presumably converts dark-on-light
            # strokes to the light-on-dark style the model was trained on --
            # confirm against the training data.
            input_image1 = gr.Image(
                label="画像入力",  # "Image input"
                source="canvas",
                type="pil",
                image_mode="L",
                shape=(28, 28),
                invert_colors=True,
            )
            send_btn1 = gr.Button("予測する")  # "Predict"
        with gr.Tab("画像ファイル"):  # "Image file"
            input_image2 = gr.Image(
                label="画像入力",  # "Image input"
                type="pil",
                image_mode="L",
                shape=(28, 28),
                invert_colors=True,
            )
            send_btn2 = gr.Button("予測する")  # "Predict"
            gr.Examples(
                ["examples/example02.png", "examples/example04.png"],
                inputs=input_image2,
            )
        # Shared output: top-5 class probabilities ("Prediction probabilities").
        output_label = gr.Label(label="予測確率", num_top_classes=5)
    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)

# Request queueing is intentionally left disabled; uncomment to let Gradio 3.x
# process up to 3 queued requests concurrently.
# demo.queue(concurrency_count=3)
demo.launch()
| ### EOF ### |