# Hugging Face Spaces status: Runtime error
| from gradio_imageslider import ImageSlider | |
| import functools | |
| import os | |
| import tempfile | |
| import diffusers | |
| import gradio as gr | |
| import imageio as imageio | |
| import numpy as np | |
| import spaces | |
| import torch as torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| import gradio | |
| from gradio.utils import get_cache_folder | |
| from infer import lotus, lotus_video | |
| import transformers | |
# ``move_cache`` is presumably a one-off migration of an older transformers
# cache layout. It is optional housekeeping, so skip it when the installed
# transformers no longer provides it instead of failing at import time.
_move_cache = getattr(transformers.utils, "move_cache", None)
if callable(_move_cache):
    _move_cache()

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def infer(path_input, seed):
    """Run LOTUS normal estimation on one image and save both predictions.

    Args:
        path_input: Filesystem path of the input image (the UI's ``gr.Image``
            uses ``type="filepath"``).
        seed: Seed forwarded to ``lotus``; the UI labels it as affecting only
            the generative output.

    Returns:
        Two ``[input_path, output_path]`` pairs: generative first, then
        discriminative. This is the before/after pair shape fed to the
        ``ImageSlider`` outputs.

    Raises:
        gr.Error: If no input image was provided (e.g. Predict clicked on an
            empty form).
    """
    if not path_input:
        raise gr.Error("Please upload an input image first.")

    output_g, output_d = lotus(path_input, 'normal', seed, device)

    output_dir = "files/output"
    # exist_ok avoids the check-then-create race when two requests (e.g. the
    # examples gallery and the Predict button) run at the same time.
    os.makedirs(output_dir, exist_ok=True)

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    # Assumes lotus returns PIL images, whose ``save`` picks the encoder from
    # the extension; fall back to PNG when the upload has no extension.
    if not name_ext:
        name_ext = ".png"
    g_save_path = os.path.join(output_dir, f"{name_base}_g{name_ext}")
    d_save_path = os.path.join(output_dir, f"{name_base}_d{name_ext}")
    output_g.save(g_save_path)
    output_d.save(d_save_path)
    return [path_input, g_save_path], [path_input, d_save_path]
def infer_video(path_input, seed):
    """Run LOTUS normal estimation on a video and write both results as MP4.

    Args:
        path_input: Filesystem path of the input video.
        seed: Seed forwarded to ``lotus_video``.

    Returns:
        ``[generative_mp4_path, discriminative_mp4_path]``, both under
        ``files/output``.

    Raises:
        gr.Error: If no input video was provided.
    """
    if not path_input:
        raise gr.Error("Please upload an input video first.")

    frames_g, frames_d = lotus_video(path_input, 'normal', seed, device)

    output_dir = "files/output"
    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs(output_dir, exist_ok=True)

    name_base, _ = os.path.splitext(os.path.basename(path_input))
    g_save_path = os.path.join(output_dir, f"{name_base}_g.mp4")
    d_save_path = os.path.join(output_dir, f"{name_base}_d.mp4")
    # NOTE(review): no fps is passed, so imageio's writer default is used
    # rather than the input video's frame rate. Confirm this is intended.
    imageio.mimsave(g_save_path, frames_g)
    imageio.mimsave(d_save_path, frames_d)
    return [g_save_path, d_save_path]
def run_demo_server():
    """Build the "LOTUS (Normal)" Gradio page and launch it on 0.0.0.0:7860.

    The page has:

    - an input image and an integer seed;
    - Predict / Reset buttons;
    - two ``ImageSlider`` outputs (generative and discriminative);
    - an examples gallery built from ``files/images``.

    Predict and the examples both call ``infer`` wrapped with ``spaces.GPU``.
    The request queue is enabled with ``api_open=False`` before launching.
    """
    # spaces.GPU wraps the handler directly; the former functools.partial bound
    # no arguments and added nothing.
    infer_gpu = spaces.GPU(infer)
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="LOTUS (Normal)",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction
            <p align="center">
            <a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
            </a>
            <a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://x.com/Jingheya/status/1839553365870784563" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            <a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            <br>
            <strong>Please consider starring <span style="color: orange">★</span> the <a href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
            </p>
            """
        )

        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(
                        label="Input Image",
                        type="filepath",
                    )
                    # precision=0 makes Gradio deliver the seed as an int rather
                    # than a possibly fractional number.
                    seed = gr.Number(
                        label="Seed (only for Generative mode)",
                        minimum=0,
                        maximum=999999999,
                        precision=0,
                    )
                    with gr.Row():
                        image_submit_btn = gr.Button(
                            value="Predict Normal!", variant="primary"
                        )
                        image_reset_btn = gr.Button(value="Reset")
                with gr.Column():
                    image_output_g = ImageSlider(
                        label="Output (Generative)",
                        type="filepath",
                        interactive=False,
                        elem_classes="slider",
                        position=0.25,
                    )
                    with gr.Row():
                        image_output_d = ImageSlider(
                            label="Output (Discriminative)",
                            type="filepath",
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

            # Each example is [image_path, seed=0]; hidden files (e.g.
            # .DS_Store) are skipped so they don't show up as broken examples.
            example_dir = os.path.join("files", "images")
            gr.Examples(
                fn=infer_gpu,
                examples=sorted(
                    [os.path.join(example_dir, name), 0]
                    for name in os.listdir(example_dir)
                    if not name.startswith(".")
                ),
                inputs=[image_input, seed],
                outputs=[image_output_g, image_output_d],
                cache_examples=False,
            )

        ### Image
        image_submit_btn.click(
            fn=infer_gpu,
            inputs=[image_input, seed],
            outputs=[image_output_g, image_output_d],
        )
        # The lambda yields three values, so it must also clear the input
        # image. Previously only the two sliders were wired, so the third value
        # was dropped and the input image was never reset.
        image_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[image_input, image_output_g, image_output_d],
            queue=False,
        )

    ### Server launch
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    """Log installed packages, remove outputs from earlier runs, start the demo."""
    import shutil
    import subprocess
    import sys

    # Dump the environment to the logs. Run pip through the current interpreter
    # so the listing matches the packages this process actually imports.
    # check=False keeps this best-effort, as the former os.system call was.
    subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
    # Remove results from a previous run without shelling out to ``rm -rf``;
    # ignore_errors covers the "directory does not exist" case.
    shutil.rmtree("files/output", ignore_errors=True)
    run_demo_server()


if __name__ == "__main__":
    main()