# character-chain / live_preview_helper.py
import torch
def sd_live_preview(pipe, prompt, num_inference_steps=30, width=512, height=512, **kwargs):
    """
    Generator function for live-preview-like SD inference.
    Yields intermediate images at increasing step counts, ending with the final image.
    Note: true live preview requires a custom pipeline or hacky callback patching;
    this simulates it by breaking the run into separate forward passes.
    """
    preview_steps = [6, 12, 20, num_inference_steps]
    # Remove duplicates and keep only step counts within bounds
    preview_steps = sorted(set(s for s in preview_steps if s <= num_inference_steps))
    image = None
    for steps in preview_steps:
        with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
            result = pipe(
                prompt,
                num_inference_steps=steps,
                width=width,
                height=height,
                **kwargs
            )
        image = result.images[0]
        yield image  # Each yielded image is a PIL.Image
    # Optional: ensure the final full-step image is always yielded last
    if preview_steps[-1] != num_inference_steps:
        with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                **kwargs
            )
        image = result.images[0]
        yield image
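
# Example usage (a minimal sketch): the pipeline class and checkpoint below are
# illustrative assumptions, not part of this helper.
#
#     from diffusers import StableDiffusionPipeline
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-v1-5",
#         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
#     ).to("cuda" if torch.cuda.is_available() else "cpu")
#
#     for i, preview in enumerate(sd_live_preview(pipe, "a watercolor fox")):
#         preview.save(f"preview_{i}.png")  # later previews use more steps
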
# live_preview_helpers.py
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt,
    guidance_scale=7.5,
    num_inference_steps=20,
    width=512,
    height=512,
    generator=None,
    output_type="pil",
    good_vae=None,
):
    """
    Yields a single image matching the prompt. The pipeline is invoked once at
    the full step count; `good_vae` is accepted for interface compatibility but
    is not used here.
    """
    image = self(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type=output_type,
    ).images[0]
    yield image
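
# A minimal sketch of how the helper above is typically attached to a pipeline
# and consumed. The FluxPipeline class and the FLUX.1-schnell checkpoint are
# assumptions for illustration, not part of the original helper.
if __name__ == "__main__":
    import types

    from diffusers import FluxPipeline

    flux_pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    # Bind the helper so it can be called like a pipeline method (self == flux_pipe).
    flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = types.MethodType(
        flux_pipe_call_that_returns_an_iterable_of_images, flux_pipe
    )

    for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        "a watercolor fox", num_inference_steps=4, width=512, height=512
    ):
        img.save("flux_preview.png")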