# aducsdr's picture
# Update app.py
# c9b258b verified
import gradio as gr
import torch
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
# --- Automatic model download block ---
# Location of the ByteMorpher DiT weights on local disk.
WEIGHTS_DIR = "./pretrained_weights/ByteMorpher"
MODEL_FILENAME = "dit.safetensors"
MODEL_PATH = os.path.join(WEIGHTS_DIR, MODEL_FILENAME)

# Create the weights directory up front; a no-op when it already exists.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

if os.path.exists(MODEL_PATH):
    # Weights are already on disk, so nothing needs to be fetched.
    print(f"Modelo já existe em {MODEL_PATH}. Pulando o download.")
else:
    print(f"Modelo não encontrado em {MODEL_PATH}. Baixando do Hugging Face Hub...")
    try:
        hf_hub_download(
            repo_id="ByteDance-Seed/BM-Model",
            filename=MODEL_FILENAME,
            local_dir=WEIGHTS_DIR,
            local_dir_use_symlinks=False,
        )
        print("Download do modelo concluído com sucesso.")
    except Exception as err:
        # Best effort: a failed download is only reported here; the module
        # keeps loading and execution continues past this block.
        print(f"Ocorreu um erro durante o download do modelo: {err}")
# --- End of download block ---
from image_datasets.dataset import image_resize
# Inference settings are loaded once, at import time, from the YAML config.
args = OmegaConf.load("inference_configs/inference.yaml")

# Sampler instance shared across requests; generate() creates it on first use.
sampler = None

# Precision and device used when moving the condition image tensor in generate().
dtype = torch.bfloat16
device = torch.device("cuda")
def generate(image: Image.Image, edit_prompt: str):
    """Run the sampler on a condition image and an edit instruction.

    The XFluxSampler is built on the first call and stored in the
    module-level ``sampler`` global, so subsequent calls reuse it.

    Args:
        image: Condition image coming from the UI as a PIL image.
        edit_prompt: Text describing the edit to apply.

    Returns:
        The value returned by the sampler call, passed through unchanged.

    Raises:
        gr.Error: If no condition image was supplied.
    """
    # Imported at call time, so the pipeline module is only loaded once a
    # generation is actually requested.
    from src.flux.xflux_pipeline import XFluxSampler

    global sampler

    # Fail early with a readable message instead of an AttributeError on None,
    # and before any sampler initialisation is attempted.
    if image is None:
        raise gr.Error("Please provide a condition image before running.")

    # Resolve each optional flag once, with the same defaults for both sampler
    # construction and the call below, so the two can never disagree when a
    # key is absent from the config.
    use_ip = args.get('use_ip', False)
    use_spatial_condition = args.get('use_spatial_condition', True)
    use_share_weight_referencenet = args.get('use_share_weight_referencenet', False)

    if sampler is None:
        print("Inicializando o XFluxSampler com a configuração completa...")
        sampler = XFluxSampler(
            device=device,
            ip_loaded=use_ip,
            spatial_condition=use_spatial_condition,
            share_position_embedding=args.get('share_position_embedding', True),
            use_share_weight_referencenet=use_share_weight_referencenet,
            double_block_refnet=args.get('double_block_refnet', True),
            single_block_refnet=args.get('single_block_refnet', False),
        )

    # Resize the input to the configured sample size, rounded down to a
    # multiple of 32.
    # NOTE(review): the sampler below still receives the unrounded
    # sample_width / sample_height; confirm it aligns them itself when the
    # configured size is not a multiple of 32.
    target_width = (args.sample_width // 32) * 32
    target_height = (args.sample_height // 32) * 32

    # Force three channels so the H x W x C -> C x H x W permute below is
    # valid for RGBA, grayscale and palette inputs as well.
    img = image.convert("RGB").resize(
        (target_width, target_height), Image.Resampling.LANCZOS
    )
    # Map 8-bit pixel values from [0, 255] to [-1, 1].
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    # Channels first, plus a leading batch dimension of 1.
    img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)

    use_image_conditioning = use_spatial_condition or use_share_weight_referencenet

    # A configured seed of -1 means "pick a random seed". Draw it as int64 so
    # the upper bound is valid regardless of the platform's default integer
    # width, and hand the sampler a plain Python int.
    seed = args.seed
    if seed == -1:
        seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))

    result = sampler(
        prompt=edit_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        image_prompt=None,
        true_gs=args.cfg_scale,
        seed=seed,
        ip_scale=args.ip_scale if use_ip else 1.0,
        source_image=img if use_image_conditioning else None,
    )
    return result
def get_samples():
    """Return the example rows for the demo's Examples widget.

    Each row is a two-item list: the sample image opened from its path and
    resized to 512x512, followed by the edit prompt that goes with it.
    """
    catalog = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
        {
            "image": "assets/0_camera_zoom/168836781.png",
            "edit_prompt": "The camera moves slightly closer to the person in the red raincoat.",
        },
        {
            "image": "assets/0_camera_zoom/195278796.png",
            "edit_prompt": "A blue sign with white text and a white sign with green text appear at the bottom of the frame, and the camera zooms out.",
        },
        {
            "image": "assets/0_camera_zoom/242167914.png",
            "edit_prompt": "The person in the foreground moves further away from the camera.",
        },
        {
            "image": "assets/1_camera_motion/205012085.png",
            "edit_prompt": "The camera moves slightly downward.",
        },
        {
            "image": "assets/1_camera_motion/238430441.png",
            "edit_prompt": "The camera angle changes, tilting slightly to the left and downward.",
        },
        {
            "image": "assets/2_object_motion/34440751.png",
            "edit_prompt": "The train moves forward, and a station building appears on the left side of the frame.",
        },
        {
            "image": "assets/2_object_motion/47140330.png",
            "edit_prompt": "The train on the bridge disappears.",
        },
        {
            "image": "assets/2_object_motion/65531461.png",
            "edit_prompt": "The jet bridge retracts from the airplane.",
        },
        {
            "image": "assets/2_object_motion/236575633.png",
            "edit_prompt": "The puppy on the left moves its head to face forward.",
        },
        {
            "image": "assets/3_human_motion/473660.png",
            "edit_prompt": "The person's arms are raised higher in the second frame.",
        },
        {
            "image": "assets/3_human_motion/114875262.png",
            "edit_prompt": "The person moves from a prone position with arms extended forward to a kneeling position on the mat.",
        },
        {
            "image": "assets/3_human_motion/133541209.png",
            "edit_prompt": "The person's right arm changes from being bent with their hand near their head to giving a thumbs-up gesture.",
        },
        {
            "image": "assets/3_human_motion/152522070.png",
            "edit_prompt": "The person tilts their head downwards.",
        },
        {
            "image": "assets/3_human_motion/158685768.png",
            "edit_prompt": "The person turns their head to the right.",
        },
        {
            "image": "assets/4_interaction/142739045.png",
            "edit_prompt": "Milk is poured into the bowl of cereal, and the glass is lowered and partially emptied.",
        },
        {
            "image": "assets/4_interaction/146371498.png",
            "edit_prompt": "The hand with the glove moves closer to the black and wooden object, lifting it off the surface.",
        },
        {
            "image": "assets/4_interaction/148905535.png",
            "edit_prompt": "The hand holding the pen moves downwards, and the pen is no longer visible.",
        },
        {
            "image": "assets/4_interaction/151416962.png",
            "edit_prompt": "The person lowers the phone from their ear and looks at it.",
        },
        {
            "image": "assets/4_interaction/165994252.png",
            "edit_prompt": "The person lifts the box off the table.",
        },
        {
            "image": "assets/4_interaction/220356955.png",
            "edit_prompt": "The person lowers the cup and places it on the table.",
        },
        {
            "image": "assets/4_interaction/231403861.png",
            "edit_prompt": "The person tilts their head to the right and raises the pineapple closer to their face.",
        },
        {
            "image": "assets/4_interaction/234177339.png",
            "edit_prompt": "The person changes their hand position from holding their face to holding a phone.",
        },
    ]

    rows = []
    for entry in catalog:
        # Every example is shown at a fixed 512x512 size.
        thumbnail = Image.open(entry["image"]).resize((512, 512))
        rows.append([thumbnail, entry["edit_prompt"]])
    return rows
def create_app():
    """Assemble the Gradio Blocks interface for the demo and return it.

    Top to bottom, the page holds: a header with project links, an input
    panel (condition image, edit prompt, Run button) next to an output
    image panel, a row of examples, and a footer credit. The Run button
    calls ``generate`` with the image and prompt. The app is not launched
    here.
    """
    # The HTML literals are kept flush-left so the markup handed to gr.HTML
    # carries no extra leading whitespace.
    header_html = """
<div style="text-align: center;">
<h2>ByteMorpher</h2>
<a href="https://arxiv.org/abs/2506.03107" target="_blank"><img src="https://img.shields.io/badge/arXiv-Paper-red" style="display:inline-block;"></a>
<a href="https://boese0601.github.io/bytemorph/" target="_blank"><img src="https://img.shields.io/badge/Project-Website-blue" style="display:inline-block;"></a>
<a href="https://github.com/ByteDance-Seed/BM-code" target="_blank"><img src="https://img.shields.io/github/stars/Boese0601/ByteMorph?label=GitHub%20%E2%98%85&logo=github&color=green" style="display:inline-block;"></a>
<a href="https://huggingface.co/datasets/ByteDance-Seed/BM-6M" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Dataset-yellow" style="display:inline-block;"></a>
<a href="https://huggingface.co/datasets/ByteDance-Seed/BM-6M-Demo" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Dataset_Demo-yellow" style="display:inline-block;"></a>
<a href="https://huggingface.co/datasets/ByteDance-Seed/BM-Bench" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace%20-Benchmark-yellow" style="display:inline-block;"></a>
<a href="https://huggingface.co/ByteDance-Seed/BM-Model" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-yellow" style="display:inline-block;"></a>
</div>
"""
    footer_html = """
<div style="text-align: center;">
* This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
</div>
"""

    with gr.Blocks() as app:
        gr.HTML(header_html)

        with gr.Row(equal_height=False):
            # Left panel: user inputs and the trigger button.
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            # Right panel: the generated result.
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            # Selecting an example fills the image and prompt inputs.
            gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )

        run_inputs = [original_image, edit_prompt]
        submit_btn.click(fn=generate, inputs=run_inputs, outputs=output_image)

        gr.HTML(footer_html)

    return app
if __name__ == "__main__":
    # Script entry point: build the UI and serve it without debug mode and
    # without a public share link.
    app = create_app()
    app.launch(share=False, debug=False)