Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -110,7 +110,7 @@ parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_f
|
|
| 110 |
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
|
| 111 |
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
|
| 112 |
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
|
| 113 |
-
parser.add_argument('--fps', type=float, default=
|
| 114 |
args = parser.parse_args()
|
| 115 |
|
| 116 |
gpu = "cuda"
|
|
@@ -257,7 +257,7 @@ pipeline.to(dtype=torch.float16).to(gpu)
|
|
| 257 |
|
| 258 |
@torch.no_grad()
|
| 259 |
@spaces.GPU
|
| 260 |
-
def video_generation_handler_streaming(prompt, seed=42, fps=
|
| 261 |
"""
|
| 262 |
Generator function that yields .ts video chunks using PyAV for streaming.
|
| 263 |
Now optimized for block-based processing.
|
|
@@ -277,14 +277,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 277 |
pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
|
| 278 |
pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
|
| 279 |
|
| 280 |
-
#
|
| 281 |
-
noise = torch.randn([1,
|
| 282 |
|
| 283 |
vae_cache, latents_cache = None, None
|
| 284 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 285 |
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
| 286 |
|
| 287 |
-
num_blocks =
|
| 288 |
current_start_frame = 0
|
| 289 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 290 |
|
|
@@ -369,7 +369,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 369 |
|
| 370 |
frame_status_html = (
|
| 371 |
f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
|
| 372 |
-
f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'
|
| 373 |
f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
|
| 374 |
f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
|
| 375 |
f" </div>"
|
|
@@ -407,7 +407,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 407 |
current_start_frame += current_num_frames
|
| 408 |
|
| 409 |
# 메모리 효율성을 위한 GPU 캐시 정리
|
| 410 |
-
if idx < num_blocks - 1 and idx %
|
| 411 |
torch.cuda.empty_cache()
|
| 412 |
|
| 413 |
# Final completion status
|
|
@@ -456,7 +456,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 456 |
f"        📊 Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
|
| 457 |
f" </p>"
|
| 458 |
f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
|
| 459 |
-
f"        🎬 Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} • FPS: {fps} • Size: {file_size_mb:.1f} MB"
|
| 460 |
f" </p>"
|
| 461 |
f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
|
| 462 |
f"        💾 Click the download button below to save your video!"
|
|
@@ -479,8 +479,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 479 |
|
| 480 |
# --- Gradio UI Layout ---
|
| 481 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
| 482 |
-
gr.Markdown("# π Self-Forcing Video Generation
|
| 483 |
-
gr.Markdown("Real-time
|
| 484 |
|
| 485 |
with gr.Row():
|
| 486 |
with gr.Column(scale=2):
|
|
@@ -506,6 +506,7 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 506 |
)
|
| 507 |
|
| 508 |
gr.Markdown("### ⚙️ Settings")
|
|
|
|
| 509 |
with gr.Row():
|
| 510 |
seed = gr.Number(
|
| 511 |
label="Seed",
|
|
@@ -515,12 +516,12 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 515 |
)
|
| 516 |
fps = gr.Slider(
|
| 517 |
label="Playback FPS",
|
| 518 |
-
minimum=
|
| 519 |
maximum=30,
|
| 520 |
value=args.fps,
|
| 521 |
step=1,
|
| 522 |
-
visible=
|
| 523 |
-
info="
|
| 524 |
)
|
| 525 |
|
| 526 |
with gr.Column(scale=3):
|
|
@@ -548,8 +549,9 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 548 |
|
| 549 |
# 다운로드용 파일 출력
|
| 550 |
download_file = gr.File(
|
| 551 |
-
label="📥 Download Video",
|
| 552 |
-
visible=False
|
|
|
|
| 553 |
)
|
| 554 |
|
| 555 |
# Connect the generator to the streaming video
|
|
|
|
| 110 |
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
|
| 111 |
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
|
| 112 |
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
|
| 113 |
+
parser.add_argument('--fps', type=float, default=12.0, help="Playback FPS for frame streaming.")
|
| 114 |
args = parser.parse_args()
|
| 115 |
|
| 116 |
gpu = "cuda"
|
|
|
|
| 257 |
|
| 258 |
@torch.no_grad()
|
| 259 |
@spaces.GPU
|
| 260 |
+
def video_generation_handler_streaming(prompt, seed=42, fps=12):
|
| 261 |
"""
|
| 262 |
Generator function that yields .ts video chunks using PyAV for streaming.
|
| 263 |
Now optimized for block-based processing.
|
|
|
|
| 277 |
pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
|
| 278 |
pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
|
| 279 |
|
| 280 |
+
# 노이즈 텐서 크기
|
| 281 |
+
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
|
| 282 |
|
| 283 |
vae_cache, latents_cache = None, None
|
| 284 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 285 |
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
| 286 |
|
| 287 |
+
num_blocks = 7 # 원래 설정으로 복원
|
| 288 |
current_start_frame = 0
|
| 289 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 290 |
|
|
|
|
| 369 |
|
| 370 |
frame_status_html = (
|
| 371 |
f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
|
| 372 |
+
f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>🎬 Generating Video...</p>"
|
| 373 |
f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
|
| 374 |
f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
|
| 375 |
f" </div>"
|
|
|
|
| 407 |
current_start_frame += current_num_frames
|
| 408 |
|
| 409 |
# 메모리 효율성을 위한 GPU 캐시 정리
|
| 410 |
+
if idx < num_blocks - 1 and idx % 3 == 2: # 3블록마다 캐시 정리
|
| 411 |
torch.cuda.empty_cache()
|
| 412 |
|
| 413 |
# Final completion status
|
|
|
|
| 456 |
f"        📊 Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
|
| 457 |
f" </p>"
|
| 458 |
f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
|
| 459 |
+
f"        🎬 Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} • FPS: {fps} • Duration: {video_duration:.1f}s • Size: {file_size_mb:.1f} MB"
|
| 460 |
f" </p>"
|
| 461 |
f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
|
| 462 |
f"        💾 Click the download button below to save your video!"
|
|
|
|
| 479 |
|
| 480 |
# --- Gradio UI Layout ---
|
| 481 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
| 482 |
+
gr.Markdown("# 🚀 Self-Forcing Video Generation")
|
| 483 |
+
gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B | 5-6 seconds duration [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
| 484 |
|
| 485 |
with gr.Row():
|
| 486 |
with gr.Column(scale=2):
|
|
|
|
| 506 |
)
|
| 507 |
|
| 508 |
gr.Markdown("### ⚙️ Settings")
|
| 509 |
+
gr.Markdown("💡 **Tip**: Adjust FPS to control video duration (8 FPS → ~10s, 10 FPS → ~8s, 12 FPS → ~6.8s, 15 FPS → ~5.4s)")
|
| 510 |
with gr.Row():
|
| 511 |
seed = gr.Number(
|
| 512 |
label="Seed",
|
|
|
|
| 516 |
)
|
| 517 |
fps = gr.Slider(
|
| 518 |
label="Playback FPS",
|
| 519 |
+
minimum=8,
|
| 520 |
maximum=30,
|
| 521 |
value=args.fps,
|
| 522 |
step=1,
|
| 523 |
+
visible=True,
|
| 524 |
+
info="Lower FPS = longer video duration"
|
| 525 |
)
|
| 526 |
|
| 527 |
with gr.Column(scale=3):
|
|
|
|
| 549 |
|
| 550 |
# 다운로드용 파일 출력
|
| 551 |
download_file = gr.File(
|
| 552 |
+
label="📥 Download Generated Video",
|
| 553 |
+
visible=False,
|
| 554 |
+
elem_id="download_file"
|
| 555 |
)
|
| 556 |
|
| 557 |
# Connect the generator to the streaming video
|