import os
import gradio as gr
from gradio_client import Client, handle_file
import torch
import spaces
from diffusers import Lumina2Pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32
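
# Note: bfloat16 halves weight/activation memory relative to float32 on CUDA
# GPUs; float32 is kept for CPU runs, where kernel support is broader.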

def set_client_for_session(request: gr.Request):
    # "stzhao/LeX-Enhancer" is a ZeroGPU space, so the caller's per-session
    # quota token could be forwarded like this (the bare header lookup would
    # raise KeyError when the header is absent, e.g. on local runs, so it is
    # kept commented out together with the forwarding call):
    # x_ip_token = request.headers["x-ip-token"]
    # return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
    return Client("stzhao/LeX-Enhancer")

# Load models
def load_models():
    pipe = Lumina2Pipeline.from_pretrained(
        "X-ART/LeX-Lumina",
        torch_dtype=torch_dtype,  # use the dtype selected above instead of hard-coding bfloat16
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)  # was pipe.to("cuda"), which fails on CPU-only machines
    return pipe

def prompt_enhance(client, image_caption, text_caption):
    combined_caption, enhanced_caption = client.predict(
        image_caption, text_caption, api_name="/generate_enhanced_caption"
    )
    return combined_caption, enhanced_caption
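
# A minimal standalone sketch of the call prompt_enhance wraps, using the same
# positional arguments and api_name as above. The sample captions are
# illustrative assumptions, not values from this app:
#
#   from gradio_client import Client
#   client = Client("stzhao/LeX-Enhancer")
#   combined, enhanced = client.predict(
#       "A poster of a mountain trail",         # image_caption
#       "\"Hike More\" in bold green letters",  # text_caption
#       api_name="/generate_enhanced_caption",
#   )
#   print(enhanced)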

pipe = load_models()

# def truncate_caption_by_tokens(caption, max_tokens=256):
#     """Truncate the caption to fit within the max token limit"""
#     tokens = tokenizer.encode(caption)
#     if len(tokens) > max_tokens:
#         truncated_tokens = tokens[:max_tokens]
#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
#     return caption

def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate an image using LeX-Lumina."""
    # Move the pipeline onto the GPU for this request. The original code mixed
    # enable_model_cpu_offload() with a later pipe.to("cpu"), which diffusers
    # does not support once offload hooks are installed.
    if torch.cuda.is_available():
        pipe.to("cuda")
    # Truncate the caption if it's too long
    # enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_trunc_ratio=1,
        cfg_normalization=True,
        max_sequence_length=256,
        generator=generator,
        system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
    ).images[0]
    # Release the GPU between requests.
    pipe.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return image
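
# Sketch of calling generate_image directly for a reproducible render. Per the
# seed convention above, 0 falls through to a random generator and any other
# value is deterministic. Caption and filename are illustrative only:
#
#   img = generate_image(
#       "A chalkboard with \"Hello\" written in white chalk",
#       seed=42,
#       num_inference_steps=40,
#       guidance_scale=7.5,
#   )
#   img.save("lex_lumina_sample.png")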

# @spaces.GPU(duration=130)
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client):
    """Run the complete pipeline from captions to final image"""
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    if enable_enhancer:
        # combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
        print(f"enhanced caption:\n{enhanced_caption}")
    else:
        enhanced_caption = combined_caption
    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
    return image, combined_caption, enhanced_caption
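
# Example of driving the full pipeline headlessly. The client is only consulted
# when enable_enhancer is True, so it may be None here; inputs are illustrative:
#
#   image, combined, final = run_pipeline(
#       "A beach sunset scene", "\"Relax\" in cursive white text",
#       seed=7, num_inference_steps=40, guidance_scale=7.5,
#       enable_enhancer=False, client=None,
#   )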

# Gradio interface
with gr.Blocks() as demo:
    client = gr.State()

    gr.Markdown("# LeX-Enhancer & LeX-Lumina Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-Lumina")

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )

            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation"
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale"
                )

            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5
            )

    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )

    # Update the label of enhanced_caption_box based on checkbox state
    def update_caption_label(enable_enhancer):
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box
    )

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )

    demo.load(set_client_for_session, None, client)

if __name__ == "__main__":
    demo.launch(debug=True)