# Hugging Face Spaces status at the time this source was captured: Runtime error
import gradio as gr
from test import inference_img
# Kept in the original order: a star import can shadow any name imported before it.
from models import *
import numpy as np
# Imported explicitly: torch.load is used below and should not depend on the star import re-exporting torch.
import torch
from PIL import Image
# Build the StyleMatte network once at import time and load its trained weights;
# the module-level `model` is what `predict` passes to `inference_img`.
device = 'cpu'
model = StyleMatte()  # presumably provided by `from models import *`
model = model.to(device)
# Relative path: resolved against the current working directory, not this file.
checkpoint = "stylematte.pth"
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# If the deployed torch supports it (>= 1.13), consider passing weights_only=True.
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch modules such as dropout / batch-norm to inference behavior
def predict(inp):
    """Run StyleMatte on one image and return its alpha matte and an RGBA cut-out.

    Args:
        inp: the image delivered by ``gr.Image(type="numpy")``; Gradio passes
            ``None`` when the user submits without an image.

    Returns:
        ``[mask, fg_pil]``. ``mask`` is the matte returned by ``inference_img``,
        clipped to [0, 1]. ``fg_pil`` is an RGBA ``PIL.Image`` whose colour
        channels are the input multiplied by the matte and whose alpha channel
        is the matte scaled to 0-255.

    Raises:
        gr.Error: if no image was provided.
    """
    if inp is None:
        # Show a readable message in the UI instead of a traceback from inference_img.
        raise gr.Error("Please upload an image first.")
    print("***********Inference****************")
    mask = inference_img(model, inp)
    # The code below treats the matte as a [0, 1] float array (it scales by 255 for
    # alpha). Clip it so that values outside that range cannot wrap around when cast
    # to uint8. In-range values are unchanged. The exact output range of
    # inference_img is not visible here.
    mask = np.clip(mask, 0.0, 1.0)
    inp_np = np.array(inp)
    # Per-pixel multiply of the image by the matte.
    # Assumes inp_np is (H, W, 3) and mask is (H, W) — TODO confirm.
    fg = (mask[:, :, None] * inp_np).astype(np.uint8)
    alpha_channel = (mask * 255).astype(np.uint8)
    print(fg.max(), alpha_channel.max(), fg.shape, alpha_channel.shape)
    print("***********Inference finish****************")
    fg_pil = Image.fromarray(np.dstack((fg, alpha_channel)), 'RGBA')
    return [mask, fg_pil]
print("MODEL LOADED")
print("************************************")

# Gradio UI: one numpy image in. Out: the raw matte (numpy) and an RGBA cut-out (PIL).
# Components are created in the same order as before: input first, then outputs.
input_component = gr.Image(type="numpy")
output_components = [
    gr.Image(type="numpy"),
    gr.Image(type="pil", image_mode='RGBA'),
]
iface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_components,
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

iface.launch()