# Scraped Hugging Face file-page header (not part of the program):
# author: sasha (HF Staff) · commit 1b8ecf8 "adding temp dir for images" · 6.23 kB
import gradio as gr
import random, os, shutil
from PIL import Image
import pandas as pd
import tempfile
def open_sd_ims(adj, group, seed):
    """Return up to nine Stable Diffusion images for an adjective/group prompt.

    The images come from ``zipped_images/stablediffusion/<prompt>.zip``. The
    archive is extracted into a directory named after the prompt (relative to
    the current working directory) only if that directory does not exist yet,
    so later calls reuse the earlier extraction.

    Args:
        adj: Optional adjective; ``''`` means "no adjective".
        group: Profession/group name; ``''`` means nothing is selected.
        seed: Seed whose ``Seed_<seed>`` sub-folder is read.

    Returns:
        ``None`` when ``group`` is empty, otherwise a list of at most nine
        fully loaded ``PIL.Image`` objects taken in file-name order.
    """
    if group == '':
        return None
    if adj != '':
        prompt = adj + '_' + group.replace(' ', '_')
    else:
        prompt = group
    if not os.path.isdir(prompt):
        archive = os.path.join('zipped_images', 'stablediffusion',
                               prompt.replace(' ', '_') + '.zip')
        shutil.unpack_archive(archive, prompt, 'zip')
    seed_dir = os.path.join(prompt, 'Seed_' + str(seed))
    images = []
    # Sort so the same nine images are chosen every time; os.listdir order is
    # arbitrary. Only the files actually returned are opened.
    for name in sorted(os.listdir(seed_dir))[:9]:
        # A bare Image.open() keeps the file handle open until the pixels are
        # read; copying inside ``with`` loads the data and closes the file now.
        with Image.open(os.path.join(seed_dir, name)) as src:
            images.append(src.copy())
    return images
def open_ims(model, adj, group):
    """Return up to nine images that ``model`` generated for an adjective/group prompt.

    Every call extracts ``zipped_images/<model key>/<prompt>.zip`` into a fresh
    temporary directory and loads the images fully into memory. The directory
    is then removed when the ``with`` block exits.

    Args:
        model: Model dropdown value, e.g. ``"Dall-E 2"`` or
            ``"Stable Diffusion 1.4"``. Its lower-cased form with spaces
            removed is the archive sub-folder under ``zipped_images``.
        adj: Optional adjective; ``''`` means "no adjective".
        group: Profession/group name; ``''`` means nothing is selected yet.

    Returns:
        ``None`` when ``group`` is empty. Otherwise a list of at most nine
        ``PIL.Image`` objects taken in file-name order. Dall-E 2 images are
        converted to RGB; for any other model the images are read from the
        ``Seed_93109`` sub-folder.

    Raises:
        shutil.ReadError: if the prompt's archive is missing or is not a zip
            file (e.g. Dall-E 2 images that have not been generated yet).
    """
    seed = 93109
    if group == '':
        return None
    if adj != '':
        prompt = adj + '_' + group.replace(' ', '_')
    else:
        prompt = group
    model_key = model.replace(' ', '').lower()
    archive = os.path.join('zipped_images', model_key,
                           prompt.replace(' ', '_') + '.zip')
    with tempfile.TemporaryDirectory() as tmpdirname:
        print('created temporary directory', tmpdirname)
        # The temp directory is brand new, so there is never a previous
        # extraction to reuse; always unpack.
        extract_dir = os.path.join(tmpdirname, model_key, prompt)
        shutil.unpack_archive(archive, extract_dir, 'zip')
        if model == "Dall-E 2":
            # Dall-E 2 archives hold the images directly; normalise to RGB.
            image_dir, mode = extract_dir, "RGB"
        else:
            # Other models keep one sub-folder per seed; a single seed is shown.
            image_dir, mode = os.path.join(extract_dir, 'Seed_' + str(seed)), None
        images = []
        # Sort so the same nine images are chosen every time; os.listdir order
        # is arbitrary. Only the files actually returned are opened.
        for name in sorted(os.listdir(image_dir))[:9]:
            # Read the pixels now and close the file. Everything under
            # tmpdirname is deleted when this ``with`` exits, and a bare
            # Image.open() would defer reading until after that.
            with Image.open(os.path.join(image_dir, name)) as src:
                images.append(src.convert(mode) if mode else src.copy())
    return images
# ---- Static option tables used to populate the UI dropdowns ----------------
vowels = list("aeiou")
prompts = pd.read_csv('promptsadjectives.csv')
seeds = [46267, 48040, 51237, 54325, 60884, 64830, 67031, 72935, 92118, 93109]

# The first ten masculine-coded and first ten feminine-coded adjectives from the
# CSV, merged alphabetically. A leading '' lets the adjective be left blank.
m_adjectives = prompts['Masc-adj'].tolist()[:10]
f_adjectives = prompts['Fem-adj'].tolist()[:10]
adjectives = [''] + sorted(m_adjectives + f_adjectives)

# Every occupation in the CSV, lower-cased and sorted for the group dropdown.
professions = sorted(p.lower() for p in prompts['Occupation-Noun'].tolist())

models = ["Stable Diffusion 1.4", "Dall-E 2"]
with gr.Blocks() as demo:
    # Page title, usage hint, and a warning about the incomplete Dall-E 2 set.
    for intro_text in (
        "# Stable Diffusion Explorer",
        "## Choose from the prompts below to explore how the [Stable Diffusion v1.4 model](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) represents different professions and adjectives",
        "Some of the images for Dall-E 2 are missing -- we are still in the process of generating them! If you get an 'error', please pick another prompt.",
    ):
        gr.Markdown(intro_text)

    with gr.Row():
        # Left panel: first model / adjective / group selection and its gallery.
        with gr.Column():
            model1 = gr.Dropdown(models, label="Choose a model to compare results",
                                 value=models[0], interactive=True)
            adj1 = gr.Dropdown(adjectives,
                               label="Choose a first adjective (or leave this blank!)",
                               interactive=True)
            choice1 = gr.Dropdown(professions, label="Choose a first group",
                                  interactive=True)
            images1 = gr.Gallery(label="Images").style(grid=[3], height="auto")
        # Right panel: the same controls for the second prompt being compared.
        with gr.Column():
            model2 = gr.Dropdown(models, label="Choose a model to compare results",
                                 value=models[0], interactive=True)
            adj2 = gr.Dropdown(adjectives,
                               label="Choose a second adjective (or leave this blank!)",
                               interactive=True)
            choice2 = gr.Dropdown(professions, label="Choose a second group",
                                  interactive=True)
            images2 = gr.Gallery(label="Images").style(grid=[3], height="auto")

    gr.Markdown(
        "### [Research](http://gender-decoder.katmatfield.com/static/documents/Gaucher-Friesen-Kay-JPSP-Gendered-Wording-in-Job-ads.pdf) has shown that "
        "certain words are considered more masculine- or feminine-coded based on how appealing job descriptions containing these words "
        "seemed to male and female research participants and to what extent the participants felt that they 'belonged' in that occupation."
    )

    # Refresh a panel's gallery whenever its group or adjective changes.
    # Listeners are registered in the same order as before
    # (choice1, choice2, adj1, adj2).
    panel1_inputs = [model1, adj1, choice1]
    panel2_inputs = [model2, adj2, choice2]
    for trigger, inputs, gallery in (
        (choice1, panel1_inputs, images1),
        (choice2, panel2_inputs, images2),
        (adj1, panel1_inputs, images1),
        (adj2, panel2_inputs, images2),
    ):
        trigger.change(open_ims, list(inputs), [gallery])
demo.launch(share=True)