# Spaces: Running on Zero

Below is the full `app.py` for a Space that dispatches each generation job to one of five ZeroGPU duration tiers, picked client-side from the requested resolution and step count.
```python
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel

from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
model = AutoModel.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
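# Note: on ZeroGPU Spaces no GPU is permanently attached; the spaces package
# (imported above, before torch is first used) makes torch report CUDA as
# available at startup, so moving the model to "cuda" here works, and the GPU
# is actually attached only while a @spaces.GPU-decorated function runs.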
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None
DEFAULT_CFG = 7.5
def _ensure_pil(x):
    """Ensure the value returned by the pipeline is a PIL.Image.Image."""
    if isinstance(x, Image.Image):
        return x
    if hasattr(x, "detach"):  # torch.Tensor with values in [0, 1]
        x = x.detach().float().clamp(0, 1).cpu().numpy()
    if isinstance(x, np.ndarray):
        if x.dtype != np.uint8:
            x = (x * 255.0).clip(0, 255).astype(np.uint8)
        if x.ndim == 3 and x.shape[0] in (1, 3, 4):  # CHW -> HWC
            x = np.moveaxis(x, 0, -1)
        if x.ndim == 3 and x.shape[-1] == 1:  # (H, W, 1) -> (H, W) for PIL
            x = x[..., 0]
        return Image.fromarray(x)
    raise TypeError("Unsupported image type returned by pipeline.")
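# A minimal, illustrative smoke test for _ensure_pil. It is our own addition,
# is never called by the UI, and exists only to document the accepted inputs.
def _selftest_ensure_pil():
    chw = torch.rand(3, 64, 64)                   # CHW float tensor in [0, 1]
    hwc = np.zeros((64, 64, 3), dtype=np.uint8)   # HWC uint8 array
    assert _ensure_pil(chw).size == (64, 64)
    assert _ensure_pil(hwc).size == (64, 64)
    assert _ensure_pil(Image.new("RGB", (8, 8))).size == (8, 8)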
def infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress):
    """Core inference logic; the @spaces.GPU decorators live on the tier wrappers."""
    if not prompt:
        gr.Warning("⚠️ Please enter a prompt!")
        return None
    with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=float(cfg),
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )
    return _ensure_pil(imgs[0])
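# The five wrappers below exist because ZeroGPU attaches the GPU for at most
# the duration declared in @spaces.GPU(duration=...), and shorter requests
# tend to be scheduled sooner. Dispatching each job to the smallest adequate
# tier (see js_dispatch below) therefore wastes less quota. The duration
# values used here are our own illustrative estimates, not measurements.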
# Tier 1: very small images with few steps
@spaces.GPU(duration=30)  # durations are illustrative; tune to measured runtimes
def infer_tiny(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 2: small to medium images with standard steps
@spaces.GPU(duration=45)
def infer_fast(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 3: standard generation for the most common cases
@spaces.GPU(duration=60)
def infer_std(prompt=None, seed=0, width=512, height=512, num_inference_steps=28, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 4: larger images or more steps
@spaces.GPU(duration=90)
def infer_long(prompt=None, seed=0, width=512, height=512, num_inference_steps=36, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 5: maximum quality with many steps
@spaces.GPU(duration=120)
def infer_max(prompt=None, seed=0, width=512, height=512, num_inference_steps=45, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
# JS dispatcher: estimate the job size client-side and click the matching hidden button
js_dispatch = """
function(width, height, steps){
    const w = Number(width);
    const h = Number(height);
    const s = Number(steps);

    // Complexity score: megapixels x steps, a rough proxy for runtime
    const pixels = w * h;
    const megapixels = pixels / 1000000;
    const complexity = megapixels * s;

    let target = 'btn-std'; // default tier

    // Select a tier based on the complexity score
    if (pixels <= 256*256 && s <= 20) {
        // Very small images with few steps
        target = 'btn-tiny';
    } else if (complexity < 5) {
        // Small images or few steps (e.g. 384x384 @ 24 steps ≈ 3.5)
        target = 'btn-fast';
    } else if (complexity < 8) {
        // Standard generation (e.g. 512x512 @ 28 steps ≈ 7.3)
        target = 'btn-std';
    } else if (complexity < 12) {
        // Larger images or more steps (e.g. 512x512 @ 40 steps ≈ 10.5)
        target = 'btn-long';
    } else {
        // Maximum complexity
        target = 'btn-max';
    }

    // Overrides for extreme values
    if (s >= 45) {
        target = 'btn-max';   // many steps always need the longest budget
    } else if (pixels >= 512*512 && s >= 35) {
        target = 'btn-long';  // large images with many steps
    }

    console.log(`Resolution: ${w}x${h}, Steps: ${s}, Complexity: ${complexity.toFixed(2)}, Selected: ${target}`);
    const b = document.getElementById(target);
    if (b) b.click();
}
"""
| css = """ | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 800px; | |
| } | |
| /* Hide the dispatcher buttons */ | |
| #btn-tiny, #btn-fast, #btn-std, #btn-long, #btn-max { | |
| display: none !important; | |
| } | |
| """ | |
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Image generation")
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt",
                             container=False)
            run_button = gr.Button("Run", scale=0, variant="primary")
            cancel_button = gr.Button("Cancel", scale=0, variant="secondary")
        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(label="Positive Prompt", show_label=True,
                                          placeholder="Optional: add positives")
                negative_prompt = gr.Text(label="Negative Prompt", show_label=True,
                                          placeholder="Optional: add negatives")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=3407)
                    num_inference_steps = gr.Slider(label="Sampling steps", minimum=10, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=512, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=512, step=64, value=512)
                cfg = gr.Slider(label="CFG (guidance scale)", minimum=0.0, maximum=20.0, step=0.5, value=DEFAULT_CFG,
                                info="Higher = closer to text, lower = more creative")
        with gr.Row():
            result_1 = gr.Image(label="Result", format="png", interactive=False)

        # Hidden dispatcher buttons
        with gr.Row(visible=False):
            btn_tiny = gr.Button(visible=False, elem_id="btn-tiny")
            btn_fast = gr.Button(visible=False, elem_id="btn-fast")
            btn_std = gr.Button(visible=False, elem_id="btn-std")
            btn_long = gr.Button(visible=False, elem_id="btn-long")
            btn_max = gr.Button(visible=False, elem_id="btn-max")
        examples = [
            ["Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
             101, 512, 512, 32, 7.5,
             "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
             "over-smoothed skin, plastic look, extra limbs, watermark"],
            ["Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
             202, 512, 384, 30, 8.5,
             "isometric view, clean lines, stylized, warm ambience, detailed furniture",
             "text, logo, watermark, perspective distortion"],
            ["Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
             303, 512, 320, 28, 7.0,
             "cinematic, volumetric light, natural colors, high dynamic range",
             "over-saturated, haze artifacts, blown highlights"],
            ["Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
             707, 384, 512, 34, 7.0,
             "textured canvas, visible brush strokes, dramatic sky, moody lighting",
             "smooth digital look, airbrush, neon colors"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt],
            label="Click & Fill Examples (Exact Size)",
        )
    # Wire each hidden dispatcher button to its tier function
    gen_inputs = [prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt]
    tier_events = []
    for btn, fn in [(btn_tiny, infer_tiny), (btn_fast, infer_fast), (btn_std, infer_std),
                    (btn_long, infer_long), (btn_max, infer_max)]:
        tier_events.append(btn.click(fn, inputs=gen_inputs, outputs=[result_1]))

    # Trigger the JS dispatcher on Run click or prompt submit
    run_button.click(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)
    prompt.submit(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)

    # Cancel aborts whichever tier event is in flight
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=tier_events)
if __name__ == "__main__":
    demo.launch()
```
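One deployment note: Gradio only honours `cancels=` for events routed through its queue. Recent Gradio releases enable the queue by default at launch, but if you are pinned to an older 3.x release you may need to enable it explicitly, as in this variant of the file's last two lines:

```python
# Variant for older Gradio releases where the queue is opt-in; without an
# active queue the Cancel button's `cancels=` wiring has no effect.
if __name__ == "__main__":
    demo.queue().launch()
```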