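"""Gradio Space for PopYou2 - VAR Text: text-conditioned Funko Pop! generation.

A class-conditional VAR model (pretrained on ImageNet) is steered toward a text
prompt by encoding the prompt with a SigLIP text encoder, mapping the embedding
into the VAR condition space with a small adapter, and sampling through
LoRA-adapted VAR layers; the frozen VQ-VAE decoder turns the predicted tokens
into an image.
"""
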
import torch
from models import VQVAE, build_vae_var
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, SiglipTextModel
from peft import LoraConfig, get_peft_model
from torchvision.transforms import ToPILImage
import random
import gradio as gr

class SimpleAdapter(nn.Module):
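    """Two-layer MLP that projects a SigLIP embedding into the VAR condition space."""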
    def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
        super(SimpleAdapter, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.norm0 = nn.LayerNorm(input_dim)
        self.activation1 = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.001)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.norm0(x)
        x = self.layer1(x)
        x = self.activation1(x)
        x = self.layer2(x)
        x = self.norm2(x)
        return x

class InferenceTextVAR(nn.Module):
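    """Inference-only wrapper: frozen VQ-VAE decoder plus LoRA-adapted VAR transformer,
    conditioned on SigLIP text embeddings through a small adapter."""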
    def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None,
                 siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
        super(InferenceTextVAR, self).__init__()
        self.device = device
        self.class_id = start_class_id
        # Build the frozen VQ-VAE and the class-conditional VAR transformer
        patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        self.vae, self.var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )
        self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
        self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
        self.adapter = SimpleAdapter(
            input_dim=self.siglip_text_encoder.config.hidden_size,
            out_dim=self.var.C  # match the VAR condition embedding dimension
        ).to(device)
        self.apply_lora_to_var()
        if pl_checkpoint is not None:
            # Load fine-tuned weights from a PyTorch Lightning checkpoint, split by submodule prefix
            state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
            var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
            vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
            adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
            self.var.load_state_dict(var_state_dict)
            self.vae.load_state_dict(vae_state_dict)
            self.adapter.load_state_dict(adapter_state_dict)
        del self.vae.encoder  # the VAE encoder is not needed for generation

    def apply_lora_to_var(self):
        """Apply LoRA (Low-Rank Adaptation) adapters to every linear layer of the VAR model."""
        def find_linear_module_names(model):
            linear_module_names = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Linear):
                    linear_module_names.append(name)
            return linear_module_names

        linear_module_names = find_linear_module_names(self.var)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=linear_module_names,
            lora_dropout=0.05,
            bias="none",
        )
        self.var = get_peft_model(self.var, lora_config)

    def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.5):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # Encode the prompt with the SigLIP text encoder
        inputs = self.text_processor([text], padding="max_length", return_tensors="pt").to(self.device)
        outputs = self.siglip_text_encoder(**inputs)
        pooled_output = outputs.pooler_output  # pooled (EOS token) states
        cond_delta = F.normalize(pooled_output, p=2, dim=-1).to(self.device)  # normalize the text embedding
        # Map into the VAR condition space and re-normalize the delta condition
        cond_delta = self.adapter(cond_delta)
        cond_delta = F.normalize(cond_delta, p=2, dim=-1)
        generated_images = self.var.autoregressive_infer_cfg(
            B=1,
            label_B=self.class_id,
            delta_condition=cond_delta[:1],
            beta=beta,
            alpha=1,
            top_k=top_k,
            top_p=top_p,
            more_smooth=more_smooth,
            g_seed=seed
        )
        image = ToPILImage()(generated_images[0].cpu())
        return image

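
# Minimal standalone usage (a sketch; assumes the fine-tuned weights file
# 'VARtext_v1.pth' used by the demo below is available locally):
#   model = InferenceTextVAR(device='cuda' if torch.cuda.is_available() else 'cpu')
#   model.load_state_dict(torch.load('VARtext_v1.pth', map_location='cpu'))
#   model.generate_image("a Funko Pop figure of an astronaut on a white background").save("out.png")
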
if __name__ == '__main__':
    # Initialize the model
    checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = InferenceTextVAR(device=device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)

    def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        print(f"Generating image for text: {text}\n"
              f"beta: {beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n")
        seed = int(seed) if seed is not None else None  # gr.Number yields a float; the generator seed must be an int
        image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
        return image

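    # Build the Gradio UI: a generation tab, example outputs, and a dropdown-based prompt builder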
    with gr.Blocks(css="""
        .project-item {margin-bottom: 30px;}
        .project-description {margin-top: 20px;}
        .github-button, .huggingface-button, .wandb-button {
            display: inline-block; margin-left: 10px; text-decoration: none; font-size: 14px;
            padding: 5px 10px; background-color: #f0f0f0; border-radius: 5px; color: black;
        }
        .project-content {display: flex; flex-direction: row;}
        .project-description {flex: 2; padding-right: 20px;}
        .project-options-image {flex: 1;}
        .funko-image {width: 100%; max-width: 300px;}
    """) as demo:
        gr.Markdown("""
# PopYou2 - VAR Text

<!-- Project Links -->
[GitHub](https://github.com/amit154154/VAR_clip) | [W&B Report](https://api.wandb.ai/links/amit154154/cqccmfsl)

## Project Explanation

- **Dataset Generation:** Generated a dataset of approximately 100,000 Funko Pop! images with detailed prompts using [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo).
- **Model Fine-tuning:** Fine-tuned the [Visual AutoRegressive (VAR)](https://arxiv.org/abs/2404.02905) model, pretrained on ImageNet, for Funko Pop! generation by injecting a custom embedding that represents the "doll" class.
- **Adapter Training:** Trained an adapter on top of the frozen [SigLIP](https://huggingface.co/google/siglip-base-patch16-224) image encoder, together with a lightweight LoRA module, to map image embeddings into the VAR condition space.
- **Text-to-Image Generation:** Enabled text-to-image generation by swapping the SigLIP image encoder for its text encoder, keeping frozen components such as the VAE and generator for efficiency and quality.

## Generate Your Own Funko Pop!
""")
        with gr.Tab("Generate Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(label="Input Text", placeholder="Enter a description for your Funko Pop!")
                    beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
                    seed_input = gr.Number(label="Seed", value=None)
                    more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
                    top_k_input = gr.Number(label="Top K", value=0)
                    top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
                    generate_button = gr.Button("Generate Image")
                with gr.Column(scale=1):
                    image_output = gr.Image(label="Generated Image")
            generate_button.click(
                generate_image_gradio,
                inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
                outputs=image_output
            )
        gr.Markdown("## Examples")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Example 1")
                gr.Markdown("A Funko Pop figure of a yellow robot Tom Cruise with headphones on a white background")
                example1_image = gr.Image(value="examples/tom_cruise_robot.png")  # Replace with the actual path
            with gr.Column():
                gr.Markdown("### Example 2")
                gr.Markdown("A Funko Pop figure of an alien Scarlett Johansson holding a shield on a white background")
                example2_image = gr.Image(value="examples/alien_Scarlett_Johansson.png")  # Replace with the actual path
            with gr.Column():
                gr.Markdown("### Example 3")
                gr.Markdown("A Funko Pop figure of a woman with a hat and pink long hair and blue dress on a white background")
                example3_image = gr.Image(value="examples/woman_pink.png")  # Replace with the actual path
        gr.Markdown("""
## Customize Your Funko Pop!

Build your own Funko Pop! by selecting options below and clicking "Generate Custom Funko Pop!".
""")
        def update_custom_image(famous_name, character, action):
            # Build the prompt based on the selections
            parts = []
            if famous_name != "None":
                parts.append(f"a Funko Pop figure of {famous_name}")
            else:
                parts.append("a Funko Pop figure")
            if character != "None":
                parts.append(f"styled as a {character}")
            if action != "None":
                parts.append(f"performing {action}")
            parts.append("on a white background")
            prompt = ", ".join(parts)
            image = model.generate_image(prompt)
            return image

        famous_name_input = gr.Dropdown(choices=["None", "Donald Trump", "Johnny Depp", "Oprah Winfrey", "Lebron James"], label="Famous Name", value="None")
        character_input = gr.Dropdown(choices=["None", "Alien", "Robot"], label="Character", value="None")
        action_input = gr.Dropdown(choices=["None", "Playing the Guitar", "Holding the Sword", "Wearing Headphones"], label="Action", value="None")
        custom_generate_button = gr.Button("Generate Custom Funko Pop!")
        custom_image_output = gr.Image(label="Custom Funko Pop!")
        custom_generate_button.click(
            update_custom_image,
            inputs=[famous_name_input, character_input, action_input],
            outputs=custom_image_output
        )

    demo.launch()