Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from src.predict import predict_interval, load_torch_net | |
def predict_gradio_canvas(x, net, device="cuda"):
    """Run the network on a single sketchpad drawing.

    Args:
        x: NumPy array produced by the Gradio sketchpad, or ``None`` when the
            canvas is empty. It must hold exactly 28 * 28 values because it is
            reshaped to ``(1, 28, 28)``; presumably a 28x28 grayscale image —
            confirm against the sketchpad configuration.
        net: Model object, passed through unchanged to ``predict_interval``.
        device: Torch device the input tensor is placed on; also forwarded to
            ``predict_interval``.

    Returns:
        ``{0: 0}`` for an empty canvas, otherwise whatever ``predict_interval``
        returns for the prepared input tensor.
    """
    # Nothing drawn yet: give the label output a placeholder mapping.
    if x is None:
        return {0: 0}

    image = torch.from_numpy(x.reshape(1, 28, 28))
    image = image.to(dtype=torch.float32, device=device)
    return predict_interval(image, net, device)
def main(device="cuda", *, checkpoint_path="../checkpoints/pytorch/version_1.pt"):
    """Load the trained network and launch a live Gradio sketchpad demo.

    Args:
        device: Torch device forwarded to ``predict_gradio_canvas`` on every
            prediction.
        checkpoint_path: Checkpoint location handed to ``load_torch_net``.
            The default is a relative path, so it is resolved against the
            current working directory rather than this file's directory;
            pass an explicit path when launching from elsewhere.
    """
    net = load_torch_net(checkpoint_path)
    # NOTE(review): the net itself is not moved to ``device`` here; confirm
    # that load_torch_net / predict_interval handle model placement.

    gr.Interface(
        # Bind the loaded net and chosen device; Gradio supplies only the
        # sketchpad value.
        fn=lambda x: predict_gradio_canvas(x, net, device),
        inputs="sketchpad",
        outputs="label",
        # Re-run the prediction automatically whenever the input changes.
        live=True,
    ).launch()
if __name__ == "__main__":
    # Script entry point: launch the demo with predictions run on the CPU.
    main(device="cpu")