Spaces:
Build error
Build error
| from __future__ import print_function | |
| import torch | |
| import process_stylization | |
| from photo_wct import PhotoWCT | |
| import gradio as gr | |
| from datetime import datetime | |
# Load the pretrained PhotoWCT weights once at import time; every request
# reuses this single module-level model instance.
model_path = './models/photo_wct.pth'
p_wct = PhotoWCT()
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU load
# on a CPU-only host. load_state_dict copies the values into the freshly built
# model's parameters anyway, so the result is the same wherever they load.
p_wct.load_state_dict(torch.load(model_path, map_location='cpu'))
def run(content_img, style_img, cuda, post_processing, fast):
    """Stylize ``content_img`` with the style of ``style_img`` (Gradio callback).

    Args:
        content_img: Content image as delivered by the first ``gr.Image`` input.
        style_img: Style image as delivered by the second ``gr.Image`` input.
        cuda: Whether the user asked to run on the GPU. If CUDA is not
            available on this host, the request falls back to the CPU and a
            message is printed instead of raising.
        post_processing: Forwarded unchanged to
            ``process_stylization.stylization_gradio``.
        fast: Index of the selected smoothing algorithm (the Radio input uses
            ``type="index"``). ``0`` selects guided image filtering
            (``GIFSmoothing``); any other value selects ``Propagator``.

    Returns:
        The result of ``process_stylization.stylization_gradio``. It is
        presumably a PIL image, since the output component uses
        ``type="pil"`` (TODO confirm).
    """
    print("[TimeStamp] {}".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))

    # Imported inside the branches so only the selected smoothing module is
    # ever imported.
    if fast == 0:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()

    # The CUDA checkbox may be ticked on a host without a usable GPU; calling
    # .cuda() there raises, so degrade to the CPU instead.
    use_cuda = bool(cuda) and torch.cuda.is_available()
    if cuda and not use_cuda:
        print("[Warning] CUDA requested but not available; running on CPU.")
    if use_cuda:
        p_wct.cuda(0)
    else:
        p_wct.to('cpu')

    return process_stylization.stylization_gradio(
        stylization_module=p_wct,
        smoothing_module=p_pro,
        content_image=content_img,
        style_image=style_img,
        # Pass the effective device, not the raw checkbox value, so downstream
        # code does not try to use CUDA after the fallback.
        cuda=use_cuda,
        post_processing=post_processing,
    )
if __name__ == '__main__':
    # Build and launch the Gradio UI; each input maps positionally onto
    # run(content_img, style_img, cuda, post_processing, fast).
    style = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Content Image'),
            gr.Image(label='Style Image'),
            # Default to the GPU only when one is actually available.
            gr.Checkbox(value=torch.cuda.is_available(), label='Use CUDA'),
            gr.Checkbox(value=True, label='Post Processing'),
            # type="index" passes the choice's position (0 or 1) to run(), so
            # the label strings are display-only. interactive=False locks the
            # selection to the default (guided image filtering).
            gr.Radio(
                choices=[
                    "Guided Image Filtering (Fast)",
                    "Photorealistic Smoothing (Slow)",
                ],
                value="Guided Image Filtering (Fast)",
                type="index",
                label="Algorithm",
                interactive=False,
            ),
        ],
        outputs=[gr.Image(type="pil", label="Result")],
    )
    style.queue()
    style.launch()