# HuMo_local / app.py
# Author: alexnasa
# Commit b6e90a6 (verified): "Update app.py"
import spaces
import gradio as gr
import sys
import os
import subprocess
import uuid
import shutil
from tqdm import tqdm
from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
import importlib, site
# Re-process every site-packages directory: site.addsitedir() puts each one on
# sys.path (if missing) and evaluates the .pth files it contains. Afterwards the
# importlib finder caches are flushed so freshly added modules can be found.
_site_package_dirs = site.getsitepackages()
for sitedir in _site_package_dirs:
    site.addsitedir(sitedir)
importlib.invalidate_caches()
def sh(cmd):
    """Run the command string *cmd* through the system shell.

    Output goes straight to this process's stdout/stderr. Raises
    subprocess.CalledProcessError when the command exits with a non-zero status.
    """
    subprocess.run(cmd, shell=True, check=True)
# Best-effort installation of optional, prebuilt accelerator wheels hosted on the Hub.
# Each install is wrapped in try/except: a failure is reported and the app keeps
# running without that extension.
flash_attention_installed = False
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # Use the running interpreter's pip (not whichever `pip` is first on PATH) so the
    # wheel lands in the environment this process imports from; the list form also
    # avoids shell word-splitting of the downloaded wheel path.
    subprocess.check_call([sys.executable, "-m", "pip", "install", flash_attention_wheel])
    # Let this already-running interpreter see the newly installed package.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    flash_attention_installed = True
except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

transformer_engine_installed = False
try:
    print("Attempting to download and install Transformer Engine wheel...")
    te_wheel = hf_hub_download(
        repo_id="alexnasa/transformer_engine_wheels",
        repo_type="model",
        filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
    )
    subprocess.check_call([sys.executable, "-m", "pip", "install", te_wheel])
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    transformer_engine_installed = True
except Exception as e:
    print(f"⚠️ Could not install Transformer Engine : {e}")
    print("Continuing without Transformer Engine ...")

# torch is imported only after the optional installs above have been attempted.
import torch

print(f"Torch version: {torch.__version__}")
print(f"FlashAttention available: {flash_attention_installed}")
print(f"Transformer Engine available: {transformer_engine_installed}")
import tempfile
from pathlib import Path
from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
from huggingface_hub import HfApi
# Pin TorchInductor's on-disk cache to a writable folder in the working directory.
# This is done first, before the model is built: torch's inductor cache_dir() helper
# may write its own default into TORCHINDUCTOR_CACHE_DIR the first time it runs, after
# which a later setdefault would silently have no effect.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space")  # or another writable path

# Fetch the HuMo checkpoint plus the Wan2.1 and Whisper weights into ./weights.
snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo")
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B")
snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3")

# Root folder for per-session outputs; run_pipeline writes into <root>/<session_id>.
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"

# Put ./humo (relative to the working directory) on sys.path so that the
# `common.config` import below resolves.
path_to_insert = "humo"
if path_to_insert not in sys.path:
    sys.path.insert(0, path_to_insert)

from common.config import load_config, create_object

# Inference config with CLI-style overrides: single sequence-parallel rank,
# 97 frames at 832x480, text/audio guidance scales, and TIA generation mode.
config = load_config(
    "./humo/configs/inference/generate.yaml",
    [
        "dit.sp_size=1",
        "generation.frames=97",
        "generation.scale_t=5.5",
        "generation.scale_a=5.0",
        "generation.mode=TIA",
        "generation.height=480",
        "generation.width=832",
    ],
)
runner = create_object(config)
def restore_inductor_cache_from_hub(
    repo_id: str,
    filename: str = "torch_compile_cache.zip",
    path_in_repo: str = "inductor_cache",
    repo_type: str = "model",
    hf_token: str | None = None,
):
    """Download a zipped TorchInductor cache from the Hub and unpack it locally.

    The archive is fetched from ``<path_in_repo>/<filename>`` in ``repo_id`` and
    extracted into the current inductor cache directory, which is created first
    if it does not exist.

    Args:
        repo_id: Hub repository holding the archive.
        filename: Archive file name inside ``path_in_repo``.
        path_in_repo: Folder in the repository that contains the archive.
        repo_type: Hub repository type passed to ``hf_hub_download``.
        hf_token: Optional access token for private repositories.
    """
    target_dir = Path(_inductor_cache_dir()).resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    remote_name = f"{path_in_repo}/{filename}"
    local_archive = hf_hub_download(
        repo_id=repo_id,
        filename=remote_name,
        repo_type=repo_type,
        token=hf_token,
    )
    shutil.unpack_archive(local_archive, extract_dir=str(target_dir))
    print(f"✓ Restored cache into {target_dir}")
# restore_inductor_cache_from_hub("alexnasa/humo-compiled")
def get_duration(prompt_text, steps, image_file, audio_file_path, max_duration=3, session_id=None, progress=None):
    """Return the ZeroGPU time budget (seconds) to reserve for one ``run_pipeline`` call.

    This is the ``duration`` callable given to ``@spaces.GPU`` on ``run_pipeline``, and it
    is invoked with that function's arguments. Its trailing parameters therefore carry
    defaults mirroring ``run_pipeline`` (``progress`` is unused here, so ``None``
    suffices). This way a call that relies on those defaults does not fail inside the
    duration callback.

    Only ``steps`` and ``max_duration`` affect the estimate.
    """
    # NOTE(review): the third parameter is named `image_file` while run_pipeline calls it
    # `image_paths`; this only matters if run_pipeline is ever called by keyword.
    return calculate_required_time(steps, max_duration)
def calculate_required_time(steps, max_duration):
    """Estimate the GPU seconds needed for a generation run.

    The estimate is ``per_step_seconds * steps + 50`` (a fixed warm-up allowance).
    ``per_step_seconds`` is looked up by frame count (``max_duration``) in a table of
    calibrated buckets. A frame count that is not itself a bucket uses the next larger
    bucket, so the reservation errs on the long side. Counts beyond the largest bucket
    use the largest one.

    Args:
        steps: Number of diffusion steps.
        max_duration: Number of frames to generate.

    Returns:
        The estimate in whole seconds.
    """
    warmup_s = 50
    # Seconds per diffusion step, keyed by generated frame count.
    max_duration_duration_mapping = {
        20: 3,
        45: 7,
        70: 13,
        95: 21,
    }
    # Alternative per-step timings labelled "Humo 1.7", kept for reference:
    # {20: 2, 45: 2, 70: 5, 95: 6}

    each_step_s = max_duration_duration_mapping.get(max_duration)
    if each_step_s is None:
        # Not a calibrated bucket (e.g. a value typed into the slider box, or
        # run_pipeline's default of 3): previously this raised KeyError.
        covering = [frames for frames in max_duration_duration_mapping if frames >= max_duration]
        bucket = min(covering) if covering else max(max_duration_duration_mapping)
        each_step_s = max_duration_duration_mapping[bucket]

    duration_s = (each_step_s * steps) + warmup_s
    print(f'estimated duration:{duration_s}')
    return int(duration_s)
def get_required_time_string(steps, max_duration):
    """Format the estimated ZeroGPU reservation as a centred one-line banner.

    The seconds value is followed by a literal ".0s", and the minutes value is
    shown with one decimal place.
    """
    total_seconds = calculate_required_time(steps, max_duration)
    return f"<center>⌚ Zero GPU Required: ~{total_seconds}.0s ({total_seconds / 60:.1f} mins)</center>"
def update_required_time(steps, max_duration):
    """Change handler for the settings sliders: return the refreshed time-estimate banner."""
    banner = get_required_time_string(steps, max_duration)
    return banner
def generate_scene(prompt_text, steps, image_paths, audio_file_path, max_duration=3, session_id=None, progress=gr.Progress()):
    """Validate the UI inputs, then delegate to ``run_pipeline``.

    These checks run outside the ``@spaces.GPU``-decorated ``run_pipeline``,
    presumably so that bad input is rejected before a GPU call is made.

    Raises:
        gr.Error: if the prompt is empty/whitespace, or if neither reference
            images nor an audio file were supplied.
    """
    has_prompt = bool((prompt_text or "").strip())
    if not has_prompt:
        raise gr.Error("Please enter a prompt.")
    if not (audio_file_path or image_paths):
        raise gr.Error("Please provide a reference image or a lipsync audio.")
    return run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration, session_id, progress)
def upload_inductor_cache_to_hub(
    repo_id: str,
    path_in_repo: str = "inductor_cache",
    repo_type: str = "model",  # or "dataset" if you prefer
    hf_token: str | None = None,
):
    """
    Zip the current TorchInductor cache and upload it to ``repo_id`` under ``path_in_repo``.

    Assumes the model was already run once with torch.compile() so the cache is populated.
    The archive is uploaded as ``torch_compile_cache.zip``, which is the default file name
    ``restore_inductor_cache_from_hub`` downloads. Next to it, an
    ``INDUCTOR_CACHE_METADATA.txt`` is uploaded, recording the torch version, cache path
    and CUDA device.

    Args:
        repo_id: Target Hub repository (created if it does not exist).
        path_in_repo: Folder in the repository that receives both files.
        repo_type: Hub repository type.
        hf_token: Optional access token.

    Raises:
        FileNotFoundError: if the cache directory is missing or empty.
    """
    cache_dir = Path(_inductor_cache_dir()).resolve()
    # torch's cache_dir() helper may create the directory on demand, so existence alone
    # does not prove the cache was populated; also reject an empty directory.
    if not cache_dir.is_dir() or not any(cache_dir.iterdir()):
        raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
                                "Run a compiled model once to populate it.")

    api = HfApi(token=hf_token)

    # Create a zip archive of the entire cache directory. The upload must happen
    # inside this context: the temporary directory (and the zip) is deleted on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        archive_base = Path(tmpdir) / "torch_compile_cache"
        archive_path = Path(shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir)))

        api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
        api.upload_file(
            path_or_fileobj=str(archive_path),
            path_in_repo=f"{path_in_repo}/{archive_path.name}",
            repo_id=repo_id,
            repo_type=repo_type,
        )

    # Small metadata stamp for traceability of the uploaded cache.
    meta_txt = (
        f"pytorch={torch.__version__}\n"
        f"inductor_cache_dir={cache_dir}\n"
        f"cuda_available={torch.cuda.is_available()}\n"
        f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
    )
    api.upload_file(
        path_or_fileobj=meta_txt.encode(),
        path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
        repo_id=repo_id,
        repo_type=repo_type,
    )
    print("✔ Uploaded TorchInductor cache to the Hub.")
def _add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path, silence_duration_s=0.5):
    """Write ``tmp_audio_path`` as ``silence_duration_s`` seconds of silence followed by ``audio_path``.

    The silence comes from ffmpeg's ``anullsrc`` source (16 kHz, stereo). It is listed
    first in the ``concat`` filter graph (``[1][0]``), so it is prepended to the input.
    Raises subprocess.CalledProcessError if ffmpeg exits with a non-zero status.
    """
    command = [
        'ffmpeg',
        '-i', audio_path,
        '-f', 'lavfi',
        '-t', str(silence_duration_s),
        '-i', 'anullsrc=r=16000:cl=stereo',
        '-filter_complex', '[1][0]concat=n=2:v=0:a=1[out]',
        '-map', '[out]',
        '-y', tmp_audio_path,
        '-loglevel', 'quiet'
    ]
    subprocess.run(command, check=True)


def _gallery_item_path(item):
    """Return the file path of one gallery entry.

    Handles both ``(path, caption)`` pairs and bare path strings/PathLikes; indexing
    ``[0]`` on a bare string would silently yield only its first character.
    """
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    return item[0]


@spaces.GPU(duration=get_duration)
def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration=3, session_id=None, progress=gr.Progress()):
    """Generate a video from a prompt plus reference images and/or lipsync audio.

    The inference mode is "TIA" when both images and audio are given, "TI" without
    audio, and "TA" without images. Outputs go to
    ``$PROCESSED_RESULTS/<session_id>``; a random session id is used when none is given.
    When audio is supplied, a copy with 0.5 s of leading silence is written to
    ``tmp_audio.wav`` in that folder and passed to the runner.

    Returns:
        The path of the generated ``.mp4``. If that file is missing, the most recently
        modified ``.mp4`` in the session folder is returned, or ``None`` if there is none.

    Raises:
        gr.Error: if the prompt is empty, or if neither images nor audio were supplied.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    # Validate inputs
    prompt_text = (prompt_text or "").strip()
    if not prompt_text:
        raise gr.Error("Please enter a prompt.")
    if not audio_file_path and not image_paths:
        raise gr.Error("Please provide a reference image or a lipsync audio.")

    inference_mode = "TIA"
    if not audio_file_path:
        inference_mode = "TI"
        audio_path = None
    else:
        audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path))

    if not image_paths:
        inference_mode = "TA"
        img_paths = None
    else:
        img_paths = [_gallery_item_path(item) for item in image_paths]

    print(f'{session_id} is using inference_mode:{inference_mode} with steps:{steps} with {max_duration} frames')

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    # Always bound: previously a truthy audio_file_path that resolved to an empty
    # path left this variable unassigned (UnboundLocalError below).
    tmp_audio_path = None
    if audio_path:
        tmp_audio_path = os.path.join(output_dir, "tmp_audio.wav")
        _add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path)

    # Random filename
    filename = f"gen_{uuid.uuid4().hex[:10]}"
    width, height = 832, 480

    runner.inference_loop(
        prompt_text,
        img_paths,
        tmp_audio_path,
        output_dir,
        filename,
        inference_mode,
        width,
        height,
        steps,
        frames=int(max_duration),
        tea_cache_l1_thresh=0.0,
        progress_bar_cmd=progress
    )

    # Return resulting video path
    video_path = os.path.join(output_dir, f"{filename}.mp4")
    if os.path.exists(video_path):
        # upload_inductor_cache_to_hub("alexnasa/humo-compiled")
        return video_path

    candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")]
    if candidates:
        return max(candidates, key=os.path.getmtime)
    return None
# Stylesheet passed to gr.Blocks(css=css): the element with elem_id="col-container"
# (the main column) is horizontally centred and capped at 720px wide.
css = """
#col-container {
margin: 0 auto;
width: 100%;
max-width: 720px;
}
"""
def cleanup(request: gr.Request):
    """Delete the per-session output folder ``$PROCESSED_RESULTS/<session_hash>``.

    Missing folders and deletion errors are ignored. Nothing happens when the
    request carries no session hash.
    """
    session_hash = request.session_hash
    if not session_hash:
        return
    session_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_hash)
    shutil.rmtree(session_dir, ignore_errors=True)
def start_session(request: gr.Request):
    """Return the Gradio session hash for this visitor (stored in the session state on page load)."""
    session_hash = request.session_hash
    return session_hash
# Gradio UI: a sidebar holding the conditioning inputs and settings, and a main column
# with the output video, prompt, time estimate, action button and cached examples.
with gr.Blocks(css=css) as demo:
    # Session hash filled on page load; passed to generate_scene as session_id.
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Sidebar(width=400):
        gr.HTML(
            """
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>HuMo</strong> – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
</p>
<a href="https://github.com/Phantom-video/HuMo" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
[Github]
</a>
</div>
"""
        )

        gr.Markdown("**REFERENCE IMAGES**")
        img_input = gr.Gallery(
            value=["./examples/ali.png"],
            show_label=False,
            label="",
            interactive=True,
            # A number is rendered in pixels; the former string "280" carried no CSS unit.
            rows=1, columns=3, object_fit="contain", height=280,
            file_types=['image']
        )

        gr.Markdown("**LIPSYNC AUDIO**")
        audio_input = gr.Audio(
            value="./examples/life.wav",
            sources=["upload"],
            show_label=False,
            type="filepath",
        )

        gr.Markdown("**SETTINGS**")
        default_steps = 10
        default_max_duration = 45
        # Slider positions 45 / 70 / 95 are calibrated frame-count buckets in calculate_required_time.
        max_duration = gr.Slider(minimum=45, maximum=95, value=default_max_duration, step=25, label="Frames")
        steps_input = gr.Slider(minimum=10, maximum=50, value=default_steps, step=5, label="Diffusion Steps")

    with gr.Column(elem_id="col-container"):
        gr.HTML(
            """
<div style="text-align: center;">
<strong>HF Space by:</strong>
<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="Follow me on Twitter">
</a>
</div>
"""
        )

        video_output = gr.Video(show_label=False)

        gr.Markdown("<center><h2>PROMPT</h2></center>")
        prompt_tb = gr.Textbox(
            value="A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
            show_label=False,
            lines=5,
            placeholder="Describe the scene and the person talking....",
        )

        gr.Markdown("")
        time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
        run_btn = gr.Button("🎬 Action", variant="primary")

        # Examples call run_pipeline directly (not generate_scene) and their outputs are cached.
        gr.Examples(
            examples=[
                [
                    "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. She is explaining something to someone standing opposite her",
                    10,
                    ["./examples/naomi.png"],
                    "./examples/science.wav",
                    70,
                ],
                [
                    "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
                    10,
                    ["./examples/ella.png"],
                    "./examples/dream.mp3",
                    45,
                ],
                [
                    "A reddish-brown haired woman sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and her thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
                    10,
                    ["./examples/art.png"],
                    "./examples/art.wav",
                    70,
                ],
                [
                    "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
                    10,
                    ["./examples/ella.png"],
                    "./examples/dream.mp3",
                    95,
                ],
                [
                    "A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
                    40,
                    ["./examples/amber.png", "./examples/jacket.png"],
                    "./examples/fictional.wav",
                    70,
                ],
            ],
            inputs=[prompt_tb, steps_input, img_input, audio_input, max_duration],
            outputs=[video_output],
            fn=run_pipeline,
            cache_examples=True,
        )

    # Keep the GPU-time estimate in sync with the settings sliders.
    max_duration.change(update_required_time, [steps_input, max_duration], time_required)
    steps_input.change(update_required_time, [steps_input, max_duration], time_required)

    run_btn.click(
        fn=generate_scene,
        inputs=[prompt_tb, steps_input, img_input, audio_input, max_duration, session_state],
        outputs=[video_output],
    )
if __name__ == "__main__":
    # Remove the session's output folder when its browser tab is closed or refreshed.
    demo.unload(cleanup)
    # queue() returns the Blocks instance itself, so launch can be chained onto it.
    demo.queue().launch(ssr_mode=False)