# Spaces: Running on Zero
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from typing import Tuple, Optional | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| from inf import InferencePipeline | |
# Curated B-LoRA repo ids. They are listed first in both model dropdowns, and
# demo_init / toggle_column draw random picks from them.
SAMPLE_MODEL_IDS = [
    'lora-library/B-LoRA-teddybear',
    'lora-library/B-LoRA-bull',
    'lora-library/B-LoRA-wolf_plushie',
    'lora-library/B-LoRA-pen_sketch',
    'lora-library/B-LoRA-cartoon_line',
    'lora-library/B-LoRA-child',
    'lora-library/B-LoRA-vase',
    'lora-library/B-LoRA-scary_mug',
    'lora-library/B-LoRA-statue',
    'lora-library/B-LoRA-colorful_teapot',
    'lora-library/B-LoRA-grey_sloth_plushie',
    'lora-library/B-LoRA-teapot',
    'lora-library/B-LoRA-backpack_dog',
    'lora-library/B-LoRA-buddha',
    'lora-library/B-LoRA-dog6',
    'lora-library/B-LoRA-poop_emoji',
    'lora-library/B-LoRA-pot',
    'lora-library/B-LoRA-fat_bird',
    'lora-library/B-LoRA-elephant',
    'lora-library/B-LoRA-metal_bird',
    'lora-library/B-LoRA-cat',
    'lora-library/B-LoRA-dog2',
    'lora-library/B-LoRA-drawing1',
    'lora-library/B-LoRA-village_oil',
    'lora-library/B-LoRA-watercolor',
    'lora-library/B-LoRA-house_3d',
    'lora-library/B-LoRA-ink_sketch',
    'lora-library/B-LoRA-drawing3',
    'lora-library/B-LoRA-crayon_drawing',
    'lora-library/B-LoRA-kiss',
    'lora-library/B-LoRA-drawing4',
    'lora-library/B-LoRA-working_cartoon',
    'lora-library/B-LoRA-painting',
    # Every entry needs a trailing comma. Without one, Python concatenates
    # adjacent string literals into a single, non-existent repo id.
    'lora-library/B-LoRA-drawing2',
    'lora-library/B-LoRA-multi-dog2',
]
# Custom stylesheet passed to gr.Blocks(css=css).
# - `#title` targets the header HTML (elem_id="title").
# - `.lora-title`, `.lora-column`, `.gr-image`, `.res-image` and `.gr-row`
#   match the elem_classes assigned to components in the UI layout.
# NOTE(review): `.svelte-iyf88w` looks like a build-generated Svelte class
# hash, so it is likely tied to one specific Gradio version. Confirm it still
# matches after upgrading Gradio.
css = """
.gradio-container {
    max-width: 900px !important;
}
#title {
    text-align: center;
}
#title h1 {
    font-size: 250%;
}
.lora-title {
    background-image: linear-gradient(to right, #314755 0%, #26a0da 51%, #314755 100%);
    text-align: center;
    border-radius: 10px;
    display: block;
}
.lora-title h2 {
    color: white !important;
}
.gr-image {
    width: 256px;
    height: 256px;
    object-fit: contain;
    margin: auto;
}
.res-image {
    object-fit: contain;
    margin: auto;
}
.lora-column {
    border: none;
    background: none;
}
.gr-row {
    align-items: center;
    justify-content: center;
    margin-top: 5px;
}
.svelte-iyf88w {
    background: none;
}
"""
def get_choices(hf_token):
    """Build the option list for the B-LoRA model dropdowns.

    Args:
        hf_token: Hugging Face token forwarded to ``HfApi`` (may be ``None``).

    Returns:
        ``'None'`` (the "no model" sentinel), then the curated
        ``SAMPLE_MODEL_IDS``, then every model authored by ``lora-library``
        on the Hub. Duplicates are removed, keeping the first occurrence, so
        curated ids that the Hub also returns appear only once.
    """
    api = HfApi(token=hf_token)
    # ``id`` is the full repo id ("owner/name").
    hub_ids = [info.id for info in api.list_models(author='lora-library')]
    # dict.fromkeys keeps insertion order, so this dedupes without reordering.
    return list(dict.fromkeys(['None'] + SAMPLE_MODEL_IDS + hub_ids))
def get_image_from_card(card, model_id) -> Optional[str]:
    """Return the URL of the first widget example image declared in a model card.

    Args:
        card: Model card object whose ``data`` attribute supports ``.get``
            (e.g. the result of ``InferencePipeline.get_model_card``).
        model_id: Hub repo id, used to build the ``resolve/main`` file URL.

    Returns:
        ``https://huggingface.co/<model_id>/resolve/main/<url>`` built from
        ``widget[0]['output']['url']``, or ``None`` when the card has no
        such entry or is malformed.
    """
    try:
        widget = card.data.get('widget')
    except AttributeError:
        # No card / no metadata: there is simply no preview image.
        return None
    # Only index into the widget list once we know it is a non-empty sequence.
    if not isinstance(widget, (list, tuple)) or not widget:
        return None
    first = widget[0]
    output = first.get('output') if isinstance(first, dict) else None
    url = output.get('url') if isinstance(output, dict) else None
    if not isinstance(url, str) or not url:
        return None
    card_path = f"https://huggingface.co/{model_id}/resolve/main/"
    return card_path + url
def demo_init():
    """Populate both B-LoRA selectors with random sample models on page load.

    Picks a random content model and a random style model from
    ``SAMPLE_MODEL_IDS``. Loads their card prompt and image, and composes
    the initial generation prompt.

    Returns:
        A 7-tuple of ``gr.update`` objects, in order: content model dropdown,
        content prompt, content image, style model dropdown, style prompt,
        style image, and combined generation prompt.

    Raises:
        The original exception type, with a context prefix, if fetching the
        choices or model info fails.
    """
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)
        # Use the same composition rule as the prompt .change handlers. It also
        # tolerates an empty instance prompt: load_model_info returns '' when
        # a card cannot be fetched, and indexing ''[0] would raise IndexError.
        combined_prompt = handle_prompt_change(content_blora_prompt, style_blora_prompt)
        return (
            gr.update(choices=choices, value=content_blora),
            gr.update(value=content_blora_prompt),
            gr.update(value=content_blora_image),
            gr.update(choices=choices, value=style_blora),
            gr.update(value=style_blora_prompt),
            gr.update(value=style_blora_image),
            gr.update(value=combined_prompt),
        )
    except Exception as e:
        raise type(e)(f'failed to demo_init, due to: {e}') from e
def toggle_column(is_checked):
    """Return the model id to put in the *opposite* B-LoRA dropdown.

    Ticking a "Use ... Only" checkbox clears the other column by returning the
    ``'None'`` sentinel. Unticking it picks a random sample model instead.
    """
    try:
        if not is_checked:
            return random.choice(SAMPLE_MODEL_IDS)
        return 'None'
    except Exception as err:
        raise type(err)(f'failed to toggle_column, due to: {err}')
def handle_prompt_change(content_blora_prompt, style_blora_prompt) -> str:
    """Compose the generation prompt from the two instance prompts.

    - both set:      "<content> in <style, first char lower-cased> style"
    - content only:  the content prompt unchanged
    - style only:    "A dog in <style, first char lower-cased> style"
    - neither:       ""
    """
    try:
        def _lower_first(text):
            # Lower-case only the leading character so the style reads
            # naturally mid-sentence.
            return text[0].lower() + text[1:]

        if not content_blora_prompt:
            if not style_blora_prompt:
                return ''
            return f'A dog in {_lower_first(style_blora_prompt)} style'
        if not style_blora_prompt:
            return content_blora_prompt
        return f'{content_blora_prompt} in {_lower_first(style_blora_prompt)} style'
    except Exception as exc:
        raise type(exc)(f'failed to handle_prompt_change, due to: {exc}')
class InferenceUtil:
    """Model-card lookups used by the Gradio callbacks."""

    def __init__(self, hf_token: str | None):
        # Token forwarded to InferencePipeline.get_model_card; may be None.
        self.hf_token = hf_token

    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
        """Return ``(instance_prompt, preview_image_url)`` for a model repo.

        Best-effort: if the model card cannot be fetched, ``('', None)`` is
        returned so the UI still renders. A missing or ``None``
        ``instance_prompt`` is normalised to ``''``, so callers always get a
        string they can safely slice.

        Raises:
            The original exception type, with a context prefix, if reading
            the fetched card fails.
        """
        try:
            card = InferencePipeline.get_model_card(lora_model_id, self.hf_token)
        except Exception:
            # Deliberately broad: any fetch failure just means "no metadata".
            return '', None
        try:
            instance_prompt = getattr(card.data, 'instance_prompt', '')
            image_url = get_image_from_card(card, lora_model_id)
        except Exception as e:
            raise type(e)(f'failed to load_model_info, due to: {e}') from e
        # The YAML header may carry `instance_prompt:` with no value (None);
        # keep the declared str return type.
        instance_prompt = '' if instance_prompt is None else str(instance_prompt)
        return instance_prompt, image_url

    def update_model_info(self, model_source: str):
        """Return ``(prompt, image_url)`` for a dropdown selection.

        The ``'None'`` sentinel, and an empty or cleared dropdown value, yield
        ``('', None)`` without contacting the Hub.
        """
        if not model_source or model_source == 'None':
            return '', None
        try:
            return self.load_model_info(model_source)
        except Exception as e:
            raise type(e)(f'failed to update_model_info, due to: {e}') from e
# Module-level singletons shared by the Gradio callbacks: `pipe` handles
# generation and `app` handles model-card lookups. Both are built with the
# same HF_TOKEN environment value (None when the variable is unset).
hf_token = os.environ.get('HF_TOKEN')
pipe = InferencePipeline(hf_token)
app = InferenceUtil(hf_token)
with gr.Blocks(css=css) as demo:
    # Header: title, paper/project links and usage hint (styled via #title).
    title = gr.HTML(
        '''<h1>Implicit Style-Content Separation using B-LoRA</h1>
<p>This is a demo for our <a href="https://arxiv.org/abs/2403.14572">paper</a>: <b>''Implicit Style-Content Separation using B-LoRA''</b>.
<br>
Project page and code are available <a href="https://b-lora.github.io/B-LoRA/">here</a>.</p>
<p>Select your favorite style and content components from the list (prefixed with <strong>`B-LoRA-`</strong>).</p>
''',
        elem_id="title"
    )

    # Side-by-side selectors: content B-LoRA on the left, style B-LoRA on the right.
    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                content_sub_title = gr.HTML('''<h2>Content B-LoRA</h2>''', elem_classes="lora-title")
                content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                content_image = gr.Image(label='Content Image', elem_classes="gr-image")
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                style_sub_title = gr.HTML('''<h2>Style B-LoRA</h2>''', elem_classes="lora-title")
                style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                style_image = gr.Image(label='Style Image', elem_classes="gr-image")

    # Prompt, output gallery and generation parameters.
    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group():
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A [c] in [s] style"'
                )
                result = gr.Gallery(label='Result', elem_classes="res-image")
            with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                content_alpha = gr.Slider(label='Content B-LoRA alpha',
                                          minimum=0,
                                          maximum=2,
                                          step=0.05,
                                          value=1)
                style_alpha = gr.Slider(label='Style B-LoRA alpha',
                                        minimum=0,
                                        maximum=2,
                                        step=0.05,
                                        value=1)
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=8888)
                # At least one denoising step is needed to produce an image.
                # NOTE(review): confirm pipe.run has no use for 0 steps.
                num_steps = gr.Slider(label='Number of Steps',
                                      minimum=1,
                                      maximum=100,
                                      step=1,
                                      value=40)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)
                num_images_per_prompt = gr.Slider(label='Number of Images per Prompt',
                                                  minimum=1,
                                                  maximum=4,
                                                  step=1,
                                                  value=2)
            run_button = gr.Button('Generate')

    # On page load: fetch dropdown choices and pre-select a random content/style pair.
    demo.load(demo_init, inputs=[],
              outputs=[content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt,
                       style_image, prompt], queue=False, show_progress="hidden")

    # Changing a model refreshes that column's instance prompt and preview image.
    content_lora_model_id.change(
        fn=app.update_model_info,
        inputs=content_lora_model_id,
        outputs=[
            content_prompt,
            content_image,
        ])
    style_lora_model_id.change(
        fn=app.update_model_info,
        inputs=style_lora_model_id,
        outputs=[
            style_prompt,
            style_image,
        ])

    # Either instance prompt changing recomposes the generation prompt.
    style_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    content_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )

    # "Use X Only" clears (or re-randomizes) the *other* column's model.
    content_checkbox.change(toggle_column, inputs=[content_checkbox],
                            outputs=[style_lora_model_id])
    style_checkbox.change(toggle_column, inputs=[style_checkbox],
                          outputs=[content_lora_model_id])

    # Positional arguments for pipe.run, in this exact order.
    inputs = [
        content_lora_model_id,
        style_lora_model_id,
        prompt,
        content_alpha,
        style_alpha,
        seed,
        num_steps,
        guidance_scale,
        num_images_per_prompt
    ]
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
# Enable the request queue (at most 10 pending jobs), then start the server
# without creating a public share link.
queued_demo = demo.queue(max_size=10)
queued_demo.launch(share=False)