Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| import shutil | |
| import pickle | |
| import random | |
| import json | |
| import os | |
# Extract the PACS data into ./data/ unless it has already been extracted.
if not os.path.exists("./data/pacs/"):
    shutil.unpack_archive("./data/pacs.zip", './data/', 'zip')

# UI display name -> on-disk directory name of each concept-learning strategy.
METHODS = {
    "Textual Inversion (LDM)": "textualinversion_ldm",
    "Textual Inversion (Stable Diffusion)": "none_with_emb_without_multires",
    "DreamBooth": "unet_without_emb_without_multires",
    "Custom Diffusion": "kv_with_emb_without_multires",
}


def _unpack_if_missing(folder):
    """Extract ``<folder>.zip`` into the working directory when ``folder`` is absent.

    The archive is extracted into ``./``, so it presumably stores paths relative
    to the app root (e.g. ``data/imagenet/...``) -- TODO confirm the zip layout.
    """
    if not os.path.exists(folder):
        shutil.unpack_archive(f"{folder}.zip", "./", 'zip')


# Generated images (plain concept and compositional prompts) for every learner.
for method in METHODS.values():
    _unpack_if_missing(f"./data/imagenet/images/{method}")
    _unpack_if_missing(f"./data/imagenet/compositions/images/{method}")

# Reference images: only the plain-images archive is unpacked for "original".
method = "original"
_unpack_if_missing(f"./data/imagenet/images/{method}")
print("Ready to go")
# UI display name -> directory name. Starts with the four PACS style domains;
# the ImageNet object concepts are appended below.
CONCEPTS = {
    "Art Painting": "art_painting",
    "Cartoon": "cartoon",
    "Photo": "photo",
    "Sketch": "sketch",
}
# Snapshot of the domain directory names, taken before objects are added.
DOMAINS = list(CONCEPTS.values())

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# acceptable only because this mapping ships with the app -- never load
# user-supplied pickles here.
with open("./data/imagenet/imagenet_mapping.pkl", "rb") as mapping_file:
    imagenet_mapping = pickle.load(mapping_file)

# Each ImageNet concept is listed as "<key>:<value>" and resolves back to its key.
OBJECTS = [f"{key}:{label}" for key, label in imagenet_mapping.items()]
CONCEPTS.update(zip(OBJECTS, imagenet_mapping))
def get_domains(method, concept):
    """Sample one generated PACS image for ``method`` plus three original images.

    Args:
        method: method directory name under ``./data/pacs``.
        concept: PACS domain directory name (e.g. ``"cartoon"``).

    Returns:
        ``(generated_image, caption, reference_images)``: the generated image
        resized to 128x128, a caption naming the sampled class and the domain,
        and a list of three 128x128 images drawn from
        ``./data/pacs/original/<concept>``.
    """
    generated_root = os.path.join('./data/pacs', method, concept)
    gen_cls = random.choice(os.listdir(generated_root))
    class_dir = os.path.join(generated_root, gen_cls)
    picked = random.choice(os.listdir(class_dir))
    gen_img = Image.open(os.path.join(class_dir, picked)).resize((128, 128))

    original_root = os.path.join('./data/pacs', 'original', concept)
    ref_images = []
    for _ in range(3):
        # A fresh class is drawn independently for every reference image.
        ref_cls = random.choice(os.listdir(original_root))
        ref_dir = os.path.join(original_root, ref_cls)
        ref_name = random.choice(os.listdir(ref_dir))
        ref_images.append(Image.open(os.path.join(ref_dir, ref_name)).resize((128, 128)))

    return gen_img, f"a photo of {gen_cls} in the style of {concept}", ref_images
def _sample_imagenet_references(concept, count=3):
    """Return ``count`` random 128x128 images from ``./data/imagenet/images/original/<concept>``."""
    original_dir = os.path.join('./data/imagenet/images', 'original', concept)
    references = []
    for _ in range(count):
        fname = random.choice(os.listdir(original_dir))
        references.append(Image.open(os.path.join(original_dir, fname)).resize((128, 128)))
    return references


def get_objects(method, concept, evaluation):
    """Sample one generated image of an ImageNet concept plus three reference images.

    Args:
        method: on-disk method directory name (a value of ``METHODS``).
        concept: ImageNet concept key (a key of ``imagenet_mapping``).
        evaluation: ``"Concept Alignment"`` samples from the plain concept
            images; any other value is treated as Compositional Reasoning and
            samples from the composition images.

    Returns:
        ``(generated_image, caption, reference_images)``: 128x128 images and a
        caption in which the concept / entity name is wrapped in ``**``.
    """
    # Textual Inversion (LDM) outputs sit one level deeper, in "samples/".
    gen_cls = "samples" if "ldm" in method else ""

    if evaluation == "Concept Alignment":
        gen_dir = os.path.join('./data/imagenet/images', method, concept, gen_cls)
        fname = random.choice(os.listdir(gen_dir))
        gen_img = Image.open(os.path.join(gen_dir, fname)).resize((128, 128))
        caption = f"a photo of **{imagenet_mapping[concept]}**"
    else:
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(f"./data/imagenet/compositions/prompts/{concept}.json", "r", encoding="utf-8") as h:
            prompts = json.load(h)
        gen_dir = os.path.join('./data/imagenet/compositions/images', method, concept, gen_cls)
        fname = random.choice(os.listdir(gen_dir))
        gen_img = Image.open(os.path.join(gen_dir, fname)).resize((128, 128))
        # The filename prefix before the first "_" is the index of its prompt.
        prompt = prompts[int(fname.split("_")[0])]
        caption = prompt["caption"].replace(prompt["entity"], f"**{prompt['entity']}**")

    return gen_img, caption, _sample_imagenet_references(concept)
def get_images(method, concept, evaluation):
    """Gradio callback: resolve the dropdown selections and sample images.

    Args:
        method: display name chosen in the "Concept Learner" dropdown.
        concept: display name chosen in the "Concept" dropdown.
        evaluation: evaluation type; only used for ImageNet object concepts.

    Returns:
        ``(generated_image, caption, reference_images)``, one value per output
        component wired to the "Get Image" button. If the learner or concept is
        missing or unknown, returns no image, an instruction message and an
        empty gallery instead of raising.
    """
    # The dropdowns are created without a default value, so an unselected
    # field presumably arrives as None; a direct dict lookup would then raise
    # KeyError and the user would only see a generic error.
    if method not in METHODS or concept not in CONCEPTS:
        return None, "Please select a concept learner and a concept.", []
    method = METHODS[method]
    concept = CONCEPTS[concept]
    if concept in DOMAINS:
        return get_domains(method, concept)
    if concept in imagenet_mapping:
        return get_objects(method, concept, evaluation)
    # Unreachable while every CONCEPTS value is a domain or an ImageNet key,
    # but the handler must still return one value per output component.
    return None, "", []
# Custom CSS passed to gr.Blocks; the selectors target the component whose
# elem_id is "image_upload".
# NOTE(review): `max-height: 5` has no unit, so browsers discard that
# declaration as invalid CSS and it has no effect -- confirm the intended
# height before fixing the value.
css='''
#image_upload{min-height:4px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{max-height: 5}
'''
# Explorer UI: two identical, independent panels side by side so two concept
# learners / concepts can be compared. Each "Get Image" button calls
# get_images with its own panel's dropdown values.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.Markdown("<h1 style='text-align: center;'>ConceptBed Benchmark Explorer</h1>")
    gr.Markdown("<h1 style='text-align: center;'><a href='https://conceptbed.github.io'>Project Page</a> | <a href='https://arxiv.org/abs/2306.04695'>Paper</a> </h1>")
    # Markdown bodies are kept at column 0: indented lines would render as code.
    gr.Markdown("""
## How to interpret results:
1. The shown three images are reference concept images learned by the diffusion model.
2. The output target concept image is generated by Stable Diffusion using selected methodologies.
3. The output text indicates the prompt used to generate the image.
""")
    gr.Markdown("""
## Types of evaluations:
1. Concept Alignment: available for all concepts
2. Compositional Reasoning: available for all concepts except -- Art Painting, Cartoon, Sketch, Photo
""")
    gr.Markdown("""
### For further details on the ConceptBed benchmark, please refer to the paper at: <a href="https://arxiv.org/abs/2306.04695">https://arxiv.org/abs/2306.04695</a>
""")
    with gr.Row():
        with gr.Column():
            methods1 = gr.Dropdown(
                list(METHODS.keys()),
                label="Concept Learner",
                info="Select a concept learning strategy."
            )
            concept1 = gr.Dropdown(
                list(CONCEPTS.keys()),
                label="Concept",
                info="Select a concept."
            )
            evaluation1 = gr.Dropdown(
                ["Concept Alignment", "Compositional Reasoning"],
                label="Evaluation Type",
                info="Select the evaluation type."
            )
            # Layout options are constructor arguments: `.style()` is
            # deprecated in Gradio 3.x and removed in Gradio 4. Each gallery
            # gets its own elem_id because DOM ids must be unique.
            gallery1 = gr.Gallery(
                label="Reference images",
                show_label=False,
                elem_id="gallery1",
                columns=3,
                rows=1,
                height=200,
            )
            image1 = gr.Image()
            text1 = gr.Textbox(label="Caption used to generate above image")
            # `full_width` is not a Button constructor parameter (ignored with
            # a warning in Gradio 3, a TypeError in Gradio 4), so it is omitted.
            btn1 = gr.Button(value="Get Image")
        with gr.Column():
            methods2 = gr.Dropdown(
                list(METHODS.keys()),
                label="Concept Learner",
                info="Select a concept learning strategy."
            )
            concept2 = gr.Dropdown(
                list(CONCEPTS.keys()),
                label="Concept",
                info="Select a concept."
            )
            evaluation2 = gr.Dropdown(
                ["Concept Alignment", "Compositional Reasoning"],
                label="Evaluation Type",
                info="Select the evaluation type."
            )
            gallery2 = gr.Gallery(
                label="Reference images",
                show_label=False,
                elem_id="gallery2",
                columns=3,
                rows=1,
                height=200,
            )
            # elem_id "image_upload" is the target of the custom CSS rules.
            image2 = gr.Image(elem_id="image_upload")
            text2 = gr.Textbox(label="Caption used to generate above image")
            btn2 = gr.Button(value="Get Image")
    btn1.click(get_images, inputs=[methods1, concept1, evaluation1], outputs=[image1, text1, gallery1])
    btn2.click(get_images, inputs=[methods2, concept2, evaluation2], outputs=[image2, text2, gallery2])
    with gr.Accordion(label="Notes", open=False):
        gr.HTML(
            """<div class="acknowledgments">
<h4>Generated Images:</h4>
<p>As ConceptBed evaluations required training of 1000+ models (one for each concept), it is impossible to host a live demo.
Therefore, we generate 200,000+ images and randomly select a few images for this demo.</p>
</div>"""
        )
# Launch the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()