Spaces:
Paused
Paused
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| import re | |
| import os | |
| import requests | |
| from customization import customize_vae_decoder | |
| from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler | |
| from torchvision import transforms | |
| from attribution import MappingNetwork | |
| import math | |
| from typing import List | |
| from PIL import Image, ImageChops | |
| import numpy as np | |
| import torch | |
# Directory (relative to the working directory) holding the fine-tuned
# checkpoints that AttributionModel loads: vae_decoder.pth,
# mapping_network.pth and decoding_network.pth.
PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
def get_image_grid(images: List[Image.Image], cols: int = 3) -> Image.Image:
    """Paste ``images`` left-to-right, top-to-bottom onto one RGB canvas.

    Every grid cell has the size of ``images[0]``; images are pasted at the
    cell's top-left corner without resizing.

    Args:
        images: Non-empty list of PIL images.
        cols: Number of grid columns (default 3, one row for the demo's
            original / attributed / difference triplet).

    Returns:
        A new RGB image of size ``(cols * width, rows * height)`` with
        ``rows = ceil(len(images) / cols)``.

    Raises:
        ValueError: if ``images`` is empty or ``cols`` is not positive.
    """
    if not images:
        raise ValueError("get_image_grid() needs at least one image")
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")
    # Derive the row count so images beyond the first row are not pasted
    # outside the canvas (and silently lost).
    rows = math.ceil(len(images) / cols)
    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image
class AttributionModel:
    """Generate Stable Diffusion 2 images with and without a fingerprint key.

    The text-to-image pipeline produces latents once; they are then decoded
    twice: by the pipeline's own VAE ("without attribution") and by a
    customized VAE whose decoder is conditioned on the output of a mapping
    network fed with a random binary key ("with attribution"). A ResNet-50
    ``decoding_network`` is also loaded; ``detect_key`` runs it on an image
    to reconstruct the key.
    """

    # Latents are multiplied by 1 / LATENT_SCALE before VAE decoding.
    LATENT_SCALE = 0.18215
    # Length of the binary key fed to the mapping network, which is also the
    # output size of the decoding network.
    KEY_DIMENSION = 32
    # Per-channel amplification of the pixel-wise difference image in `infer`.
    DIFF_FACTOR = 5

    def __init__(self):
        """Load the SD2 pipeline, the customized VAE and the fingerprint networks.

        Pipeline, scheduler and VAE come from ``stabilityai/stable-diffusion-2``
        via diffusers ``from_pretrained``; three state dicts are read from
        ``PRETRAINED_MODEL_NAME_OR_PATH``. All models are moved to CUDA when it
        is available and otherwise stay on the CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        scheduler = EulerDiscreteScheduler.from_pretrained(
            "stabilityai/stable-diffusion-2", subfolder="scheduler"
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2", scheduler=scheduler
        ).to(self.device)
        self.resize_transform = transforms.Resize(
            512, interpolation=transforms.InterpolationMode.BILINEAR
        )

        # Fingerprinting components: VAE with a customized decoder, the
        # key -> decoder-conditioning mapping network, and the key-recovery net.
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-2", subfolder="vae"
        )
        self.vae = customize_vae_decoder(vae, 128, "deqkv", "all", False, 1.0)
        self.mapping_network = MappingNetwork(
            self.KEY_DIMENSION, 0, 128, None,
            num_layers=2, w_avg_beta=None, normalization=False,
        )
        from torchvision.models import resnet50

        # Every parameter and buffer is replaced by the strict load of
        # decoding_network.pth below, so pretrained ImageNet weights would only
        # cost a download.
        self.decoding_network = resnet50(weights=None)
        self.decoding_network.fc = torch.nn.Linear(2048, self.KEY_DIMENSION)

        self.vae.decoder.load_state_dict(self._load_checkpoint("vae_decoder.pth"))
        self.mapping_network.load_state_dict(
            self._load_checkpoint("mapping_network.pth")
        )
        self.decoding_network.load_state_dict(
            self._load_checkpoint("decoding_network.pth")
        )

        self.vae = self.vae.to(self.device)
        self.mapping_network = self.mapping_network.to(self.device)
        self.decoding_network = self.decoding_network.to(self.device)

        self.test_norm = transforms.Compose(
            [
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    @staticmethod
    def _load_checkpoint(filename):
        """Load a state dict from PRETRAINED_MODEL_NAME_OR_PATH onto the CPU.

        ``map_location="cpu"`` lets checkpoints saved from GPU tensors load on
        CPU-only hosts; callers move the modules to the target device later.
        """
        return torch.load(
            os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, filename), map_location="cpu"
        )

    @staticmethod
    def _decoded_to_numpy(image):
        """Convert a decoded NCHW tensor to a float32 NHWC array in [0, 1].

        Values are clamped to [-1, 1], then mapped linearly to [0, 1].
        """
        image = image.clamp(-1, 1)
        image = (image / 2 + 0.5).clamp(0, 1)
        return image.cpu().permute(0, 2, 3, 1).float().numpy()

    def infer(self, prompt, negative, steps, guidance_scale):
        """Generate one image for ``prompt`` and decode it with and without a key.

        Args:
            prompt: Text prompt.
            negative: Negative prompt.
            steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.

        Returns:
            ``(original, attributed, difference)`` PIL images, where
            ``difference`` is ``|original - attributed|`` with every channel
            multiplied by ``DIFF_FACTOR``.
        """
        with torch.no_grad():
            out_latents = self.pipe(
                [prompt],
                negative_prompt=[negative],
                output_type="latent",
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
            ).images
        image_attr = self.inference_with_attribution(out_latents)
        image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])[0]
        image_org = self.inference_without_attribution(out_latents)
        image_org_pil = self.pipe.numpy_to_pil(image_org[0])[0]

        factor = self.DIFF_FACTOR
        # 12-entry RGB->RGB colour matrix: scale each channel, no offset.
        diff_matrix = (factor, 0, 0, 0, 0, factor, 0, 0, 0, 0, factor, 0)
        image_diff_pil = ImageChops.difference(image_org_pil, image_attr_pil).convert(
            "RGB", diff_matrix
        )
        return image_org_pil, image_attr_pil, image_diff_pil

    def inference_without_attribution(self, latents):
        """Decode ``latents`` with the pipeline's stock VAE.

        Returns a float32 NHWC numpy array in [0, 1].
        """
        latents = 1 / self.LATENT_SCALE * latents
        with torch.no_grad():
            image = self.pipe.vae.decode(latents).sample
        return self._decoded_to_numpy(image)

    def get_phis(self, phi_dimension, batch_size, eps=1e-8):
        """Sample a random binary key tensor of shape ``(batch_size, phi_dimension)``.

        Each entry is Bernoulli-sampled with a per-entry probability drawn
        uniformly from [0, 1). ``eps`` is added to every entry, so no value is
        exactly zero.
        """
        probabilities = torch.empty(batch_size, phi_dimension).uniform_(0, 1)
        return torch.bernoulli(probabilities) + eps

    def inference_with_attribution(self, latents, key=None):
        """Decode ``latents`` with the fingerprinting VAE conditioned on ``key``.

        Args:
            latents: Latents produced by ``self.pipe`` with ``output_type="latent"``.
            key: Optional key tensor; a fresh random ``(1, KEY_DIMENSION)`` key
                is sampled when omitted.

        Returns:
            A float32 NHWC numpy array in [0, 1].
        """
        if key is None:
            key = self.get_phis(self.KEY_DIMENSION, 1)
        latents = 1 / self.LATENT_SCALE * latents
        with torch.no_grad():
            # Follow the latents' device instead of hard-coding .cuda(), which
            # crashed on CPU-only hosts; all models share that device.
            conditioning = self.mapping_network(key.to(latents.device))
            image = self.vae.decode(latents, conditioning).sample
        return self._decoded_to_numpy(image)

    def postprocess(self, image):
        """Resize ``image`` so its shorter side is 512 pixels (bilinear)."""
        image = self.resize_transform(image)
        return image

    def detect_key(self, image):
        """Run the decoding network on ``image`` to reconstruct its key.

        ``image`` is mapped from [-1, 1] to [0, 1] and normalized with
        ``test_norm`` before the ResNet forward pass. The raw network output
        (``KEY_DIMENSION`` values per sample) is returned.
        NOTE(review): assumes an NCHW tensor in [-1, 1] already on the
        network's device — confirm with callers.
        """
        reconstructed_keys = self.decoding_network(
            self.test_norm((image / 2 + 0.5).clamp(0, 1))
        )
        return reconstructed_keys
# Built once at import time: AttributionModel.__init__ loads the diffusion
# pipeline and the fingerprint checkpoints, and get_images reuses this single
# instance for every request.
attribution_model = AttributionModel()
def get_images(prompt, negative, steps, guidence_scale):
    """Gradio callback returning ``[original, attributed, difference]`` images.

    Delegates to the module-level ``attribution_model``; the three images are
    returned in the order of the demo's three output components.
    """
    original, attributed, difference = attribution_model.infer(
        prompt, negative, steps, guidence_scale
    )
    return [original, attributed, difference]
# Example rows for gr.Examples, one value per input component, in order:
# [prompt, negative prompt, steps, guidance scale].
image_examples = [
    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
]
# Gradio UI: prompt inputs and sliders on top, a generate button, and three
# output images (original, fingerprinted, amplified difference).
with gr.Blocks() as demo:
    gr.Markdown(
        """<h1 style="text-align: center;"><b>WOUAF:
Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a></h1>""")
    # NOTE(review): the .style(...) calls below are Gradio 3.x API and were
    # removed in Gradio 4; pin gradio<4 or migrate these options.
    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
            negative = gr.Textbox(
                label="Enter your negative prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a negative prompt",
                elem_id="negative-prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=45, maximum=55, value=50, step=1)
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
        )
    with gr.Row():
        btn = gr.Button(value="Generate Image", full_width=False)
    with gr.Row():
        im_2 = gr.Image(type="pil", label="without attribution")
        # Component labels are shown as plain text, so Markdown emphasis
        # ("**with**") would appear as literal asterisks.
        im_3 = gr.Image(type="pil", label="with attribution")
        # Keep the "5" in sync with the diff factor used in AttributionModel.infer.
        im_4 = gr.Image(type="pil", label="pixel-wise difference multiplied by 5")
    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
    # cache_examples=True runs get_images on every example when the app starts.
    gr.Examples(
        examples=image_examples,
        inputs=[text, negative, steps, guidance_scale],
        outputs=[im_2, im_3, im_4],
        fn=get_images,
        cache_examples=True,
    )
    gr.HTML(
        """
<div class="footer">
    <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
    </p>
    <p>
        Fine-tuned by the authors for research purposes.
    </p>
</div>
"""
    )
    with gr.Accordion(label="Ethics & Privacy", open=False):
        # NOTE(review): AttributionModel loads "stabilityai/stable-diffusion-2",
        # while the text below names V2.1 — confirm which base model to cite.
        gr.HTML(
            """<div class="acknowledgments">
<h4>Privacy</h4>
<p>We do not collect any images or key data. This demo is designed solely for fun and to help reduce misuse of AI.</p>
<h4>Biases and content acknowledgment</h4>
<p>This model has the same biases as Stable Diffusion V2.1.</p>
</div>
"""
        )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()