Commit 5359939 · finish demo
Parent(s): eb1feee

Files changed:
- app.py (+549 -49)
- camera_pose.py (+94 -0)
- history_guidance.py (+24 -0)
app.py CHANGED

@@ -6,19 +6,20 @@ import gradio as gr
 import numpy as np
 import torch
 from torchvision.datasets.utils import download_and_extract_archive
-from
+from einops import repeat
 from omegaconf import OmegaConf
 from algorithms.dfot import DFoTVideoPose
-from
+from history_guidance import HistoryGuidance
 from utils.ckpt_utils import download_pretrained
-from utils.huggingface_utils import download_from_hf
 from datasets.video.utils.io import read_video
-from datasets.video import RealEstate10KAdvancedVideoDataset
 from export import export_to_video, export_to_gif, export_images_to_gif
+from camera_pose import extend_poses, CameraPose
+from scipy.spatial.transform import Rotation, Slerp
 
 DATASET_URL = "https://huggingface.co/kiwhansong/DFoT/resolve/main/datasets/RealEstate10K_Tiny.tar.gz"
 DATASET_DIR = Path("data/real-estate-10k-tiny")
-LONG_LENGTH =
+LONG_LENGTH = 10  # seconds
+NAVIGATION_FPS = 3
 
 if not DATASET_DIR.exists():
     DATASET_DIR.mkdir(parents=True)

@@ -69,8 +70,8 @@ dfot.to("cuda")
 
 def prepare_long_gt_video(idx: int):
     video = video_list[idx]
-    indices = torch.linspace(0, video.size(0) - 1,
-    return export_to_video(video[indices], fps=
+    indices = torch.linspace(0, video.size(0) - 1, 200, dtype=torch.long)
+    return export_to_video(video[indices], fps=200 // LONG_LENGTH)
 
 
 def prepare_short_gt_video(idx: int):

@@ -104,7 +105,7 @@ def single_image_to_long_video(
     xs = video[indices].unsqueeze(0).to("cuda")
     conditions = poses[indices].unsqueeze(0).to("cuda")
     dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
-    dfot.cfg.tasks.prediction.keyframe_density =
+    dfot.cfg.tasks.prediction.keyframe_density = 12 / (fps * LONG_LENGTH)
     # dfot.cfg.tasks.interpolation.history_guidance.guidance_scale = guidance_scale
     gen_video = dfot._unnormalize_x(
         dfot._predict_videos(
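Note on the rewritten `prepare_long_gt_video` above: it samples a fixed 200 evenly spaced frames and derives the export FPS from `LONG_LENGTH`. A minimal standalone sketch of that subsampling, with a hypothetical clip shape:

```python
import torch

video = torch.rand(973, 3, 256, 256)  # hypothetical (T, C, H, W) ground-truth clip
# 200 evenly spaced frame indices across the whole clip
indices = torch.linspace(0, video.size(0) - 1, 200, dtype=torch.long)
subsampled = video[indices]
assert subsampled.shape[0] == 200
# Exported at fps = 200 // LONG_LENGTH (= 20 for a 10-second target),
# the 200 frames play back as a LONG_LENGTH-second video.
```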
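On the `keyframe_density` line above: assuming the config value is the fraction of the `fps * LONG_LENGTH` output frames that the prediction stage generates directly (the rest are filled in by interpolation), `12 / (fps * LONG_LENGTH)` pins the keyframe count at roughly 12 for every FPS setting:

```python
LONG_LENGTH = 10  # seconds, as in app.py
for fps in (4, 8, 20):
    total_frames = fps * LONG_LENGTH
    keyframe_density = 12 / total_frames
    # density * total_frames == 12 keyframes, regardless of the FPS slider
    print(fps, total_frames, round(keyframe_density * total_frames))
```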
@@ -151,6 +152,228 @@ def any_images_to_short_video(
     return video_to_gif_and_images([image for image in gen_video], list(range(8)))
 
 
+class CustomProgressBar:
+    def __init__(self, pbar):
+        self.pbar = pbar
+
+    def set_postfix(self, **kwargs):
+        pass
+
+    def __getattr__(self, attr):
+        return getattr(self.pbar, attr)
+
+
+@torch.autocast("cuda")
+@torch.no_grad()
+def navigate_video(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+    x_angle: float,
+    y_angle: float,
+    distance: float,
+):
+    n_context_frames = min(len(video), 4)
+    n_prediction_frames = 8 - n_context_frames
+    pbar = CustomProgressBar(
+        gr.Progress(track_tqdm=True).tqdm(
+            iterable=None,
+            desc=f"Predicting next {n_prediction_frames} frames",
+            total=dfot.sampling_timesteps,
+        )
+    )
+    xs = dfot._normalize_x(video.clone().unsqueeze(0).to("cuda"))
+    conditions = poses.clone().unsqueeze(0).to("cuda")
+    conditions = extend_poses(
+        conditions,
+        n=n_prediction_frames,
+        x_angle=x_angle,
+        y_angle=y_angle,
+        distance=distance,
+    )
+    context_mask = (
+        torch.cat(
+            [
+                torch.ones(1, n_context_frames) * (1 if n_context_frames == 1 else 2),
+                torch.zeros(1, n_prediction_frames),
+            ],
+            dim=-1,
+        )
+        .long()
+        .to("cuda")
+    )
+    next_video = (
+        dfot._unnormalize_x(
+            dfot._sample_sequence(
+                batch_size=1,
+                context=torch.cat(
+                    [
+                        xs[:, -n_context_frames:],
+                        torch.zeros(
+                            1,
+                            n_prediction_frames,
+                            *xs.shape[2:],
+                            device=xs.device,
+                            dtype=xs.dtype,
+                        ),
+                    ],
+                    dim=1,
+                ),
+                context_mask=context_mask,
+                conditions=conditions[:, -8:],
+                history_guidance=HistoryGuidance.smart(
+                    x_angle=x_angle,
+                    y_angle=y_angle,
+                    distance=distance,
+                    visualize=False,
+                ),
+                pbar=pbar,
+            )[0]
+        )[0][n_context_frames:]
+        .detach()
+        .cpu()
+    )
+    gen_video = torch.cat([video, next_video], dim=0)
+    poses = conditions[0]
+
+    images = (gen_video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+
+    return (
+        gen_video,
+        poses,
+        images[-1],
+        export_to_video(gen_video, fps=NAVIGATION_FPS),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+
+def undo_navigation(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+):
+    if len(video) >= 8:
+        video = video[:-4]
+        poses = poses[:-4]
+    else:
+        gr.Warning("You have no moves left to undo!")
+    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+    return (
+        video,
+        poses,
+        images[-1],
+        export_to_video(video, fps=NAVIGATION_FPS),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+
+def _interpolate_conditions(conditions, indices):
+    """
+    Interpolate conditions to fill out missing frames
+
+    Args:
+        conditions (Tensor): conditions (B, T, C)
+        indices (Tensor): indices of keyframes (T')
+    """
+    assert indices[0].item() == 0
+    assert indices[-1].item() == conditions.shape[1] - 1
+
+    indices = indices.cpu().numpy()
+    batch_size, n_tokens, _ = conditions.shape
+    t = np.linspace(0, n_tokens - 1, n_tokens)
+
+    key_conditions = conditions[:, indices]
+    poses = CameraPose.from_vectors(key_conditions)
+    extrinsics = poses.extrinsics().cpu().numpy()
+    ps = extrinsics[..., :3, 3]
+    rs = extrinsics[..., :3, :3].reshape(batch_size, -1, 3, 3)
+
+    interp_extrinsics = np.zeros((batch_size, n_tokens, 3, 4))
+    for i in range(batch_size):
+        slerp = Slerp(indices, Rotation.from_matrix(rs[i]))
+        interp_extrinsics[i, :, :3, :3] = slerp(t).as_matrix()
+        for j in range(3):
+            interp_extrinsics[i, :, j, 3] = np.interp(t, indices, ps[i, :, j])
+    interp_extrinsics = torch.from_numpy(interp_extrinsics.astype(np.float32))
+    interp_extrinsics = interp_extrinsics.to(conditions.device).flatten(2)
+    conditions = repeat(key_conditions[:, 0, :4], "b c -> b t c", t=n_tokens)
+    conditions = torch.cat([conditions.clone(), interp_extrinsics], dim=-1)
+
+    return conditions
+
+
+@spaces.GPU(duration=300)
+@torch.autocast("cuda")
+@torch.no_grad()
+def _interpolate_between(
+    xs: torch.Tensor,
+    conditions: torch.Tensor,
+    interpolation_factor: int,
+    progress=gr.Progress(track_tqdm=True),
+):
+    l = xs.shape[1]
+    final_l = (l - 1) * interpolation_factor + 1
+    x_shape = xs.shape[2:]
+    context = torch.zeros(
+        (
+            1,
+            final_l,
+            *x_shape,
+        ),
+        device=xs.device,
+        dtype=xs.dtype,
+    )
+    long_conditions = torch.zeros(
+        (1, final_l, *conditions.shape[2:]),
+        device=conditions.device,
+        dtype=conditions.dtype,
+    )
+    context_mask = torch.zeros(
+        (1, final_l),
+        device=xs.device,
+        dtype=torch.bool,
+    )
+    context_indices = torch.arange(
+        0, final_l, interpolation_factor, device=conditions.device
+    )
+    context[:, context_indices] = xs
+    long_conditions[:, context_indices] = conditions
+    context_mask[:, ::interpolation_factor] = True
+    long_conditions = _interpolate_conditions(
+        long_conditions,
+        context_indices,
+    )
+
+    xs = dfot._interpolate_videos(
+        context,
+        context_mask,
+        conditions=long_conditions,
+    )
+    return xs, long_conditions
+
+
+def smooth_navigation(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+    interpolation_factor: int,
+    progress=gr.Progress(track_tqdm=True),
+):
+    if len(video) < 8:
+        gr.Warning("Navigate first before applying temporal super-resolution!")
+    else:
+        video, poses = _interpolate_between(
+            dfot._normalize_x(video.clone().unsqueeze(0).to("cuda")),
+            poses.clone().unsqueeze(0).to("cuda"),
+            interpolation_factor,
+        )
+        video = dfot._unnormalize_x(video)[0].detach().cpu()
+        poses = poses[0]
+    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+    return (
+        video,
+        poses,
+        images[-1],
+        export_to_video(video, fps=NAVIGATION_FPS * interpolation_factor),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+
+
 # Create the Gradio Blocks
 with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
     gr.HTML(
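`_interpolate_conditions` above reconstructs the in-between camera poses by slerping the rotation part of each keyframe extrinsic and linearly interpolating the translation part. A self-contained sketch of the same scheme on hypothetical data (two keyframes, five output frames):

```python
import numpy as np
from scipy.spatial.transform import Rotation, Slerp

key_times = np.array([0, 4])                        # keyframe indices
key_rots = Rotation.from_euler("y", [0, 60], degrees=True)
key_trans = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, -2.0]])

t = np.linspace(0, 4, 5)                            # all frame indices
rots = Slerp(key_times, key_rots)(t).as_matrix()    # (5, 3, 3) rotations
trans = np.stack(
    [np.interp(t, key_times, key_trans[:, j]) for j in range(3)], axis=-1
)                                                   # (5, 3) translations
extrinsics = np.concatenate([rots, trans[:, :, None]], axis=-1)  # (5, 3, 4)
```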
@@ -160,6 +383,21 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
             font-size: 16px !important;
             font-weight: bold;
         }
+        #header-button .button-icon {
+            margin-right: 8px;
+        }
+        #basic-controls {
+            column-gap: 0px;
+        }
+        #basic-controls button {
+            border: 1px solid #e4e4e7;
+        }
+        #basic-controls-tab {
+            padding: 0px;
+        }
+        #advanced-controls-tab {
+            padding: 0px;
+        }
         </style>
         """
     )
@@ -169,14 +407,29 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
         "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
     )
     with gr.Row():
-        gr.Button(value="🌐 Website", link="https://boyuan.space/history-guidance")
-        gr.Button(value="📄 Paper", link="https://arxiv.org/abs/2502.06764")
         gr.Button(
-            value="
+            value="Website",
+            link="https://boyuan.space/history-guidance",
+            icon="https://simpleicons.org/icons/googlechrome.svg",
+            elem_id="header-button",
+        )
+        gr.Button(
+            value="Paper",
+            link="https://arxiv.org/abs/2502.06764",
+            icon="https://simpleicons.org/icons/arxiv.svg",
+            elem_id="header-button",
+        )
+        gr.Button(
+            value="Code",
             link="https://github.com/kwsong0113/diffusion-forcing-transformer",
+            icon="https://simpleicons.org/icons/github.svg",
+            elem_id="header-button",
         )
         gr.Button(
-            value="
+            value="Pretrained Models",
+            link="https://huggingface.co/kiwhansong/DFoT",
+            icon="https://simpleicons.org/icons/huggingface.svg",
+            elem_id="header-button",
         )
 
     with gr.Accordion("Troubleshooting: Not Working or Too Slow?", open=False):
@@ -187,7 +440,6 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
         """
     )
 
-
     with gr.Tab("Any # of Images → Short Video", id="task-1"):
         gr.Markdown(
             """
@@ -225,7 +477,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
             def update_selection(selection: gr.SelectData):
                 return selection.index
 
-            demo1_scene_select_button = gr.Button("Select Scene")
+            demo1_scene_select_button = gr.Button("Select Scene", variant="primary")
 
            @demo1_scene_select_button.click(
                inputs=demo1_selected_scene_index, outputs=demo1_stage
@@ -257,7 +509,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
                 choices=[(f"t={i}", i) for i in range(8)],
                 value=[],
             )
-            demo1_image_select_button = gr.Button("Select Input Images")
+            demo1_image_select_button = gr.Button("Select Input Images", variant="primary")
 
            @demo1_image_select_button.click(
                inputs=[demo1_selector],
@@ -304,7 +556,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
                 info="Without history guidance: 1.0; Recommended: 4.0",
                 interactive=True,
             )
-            gr.Button("Generate Video").click(
+            gr.Button("Generate Video", variant="primary").click(
                 fn=any_images_to_short_video,
                 inputs=[
                     demo1_selected_scene_index,
@@ -316,9 +568,9 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
 
     with gr.Tab("Single Image → Long Video", id="task-2"):
         gr.Markdown(
-            """
-            ## Demo 2: Single Image → Long
-            > #### _Diffusion Forcing Transformer, with History Guidance,
+            f"""
+            ## Demo 2: Single Image → Long {LONG_LENGTH}-second Video
+            > #### _Diffusion Forcing Transformer, with History Guidance, generates long videos via sliding window rollouts and temporal super-resolution._
             """
         )
 
@@ -344,7 +596,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
            def update_selection(selection: gr.SelectData):
                return selection.index
 
-           demo2_select_button = gr.Button("Select Input Image")
+           demo2_select_button = gr.Button("Select Input Image", variant="primary")
 
           @demo2_select_button.click(
              inputs=demo2_selected_index, outputs=demo2_stage
@@ -369,49 +621,297 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
                 label="Ground Truth Video",
                 width=256,
                 height=256,
+                autoplay=True,
+                loop=True,
             )
             demo2_video = gr.Video(
-                label="Generated Video",
+                label="Generated Video",
+                width=256,
+                height=256,
+                autoplay=True,
+                loop=True,
+                show_share_button=True,
+                show_download_button=True,
             )
 
+            with gr.Sidebar():
+                gr.Markdown("### Sampling Parameters")
 
+                demo2_guidance_scale = gr.Slider(
+                    minimum=1,
+                    maximum=6,
+                    value=4,
+                    step=0.5,
+                    label="History Guidance Scale",
+                    info="Without history guidance: 1.0; Recommended: 4.0",
+                    interactive=True,
+                )
+                demo2_fps = gr.Slider(
+                    minimum=4,
+                    maximum=20,
+                    value=8,
+                    step=1,
+                    label="FPS",
+                    info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
+                    interactive=True,
+                )
+                gr.Button("Generate Video", variant="primary").click(
+                    fn=single_image_to_long_video,
+                    inputs=[
+                        demo2_selected_index,
+                        demo2_guidance_scale,
+                        demo2_fps,
+                    ],
+                    outputs=demo2_video,
+                )
+
+    with gr.Tab("Single Image → Endless Video Navigation", id="task-3"):
+        gr.Markdown(
+            """
+            ## Demo 3: Single Image → Extremely Long Video _(Navigate with Your Camera Movements!)_
+            > #### _History Guidance significantly improves quality and temporal consistency, enabling stable rollouts for extremely long videos._
+            """
+        )
+
+        demo3_stage = gr.State(value="Selection")
+        demo3_selected_index = gr.State(value=None)
+        demo3_current_video = gr.State(value=None)
+        demo3_current_poses = gr.State(value=None)
+
+        @gr.render(inputs=[demo3_stage, demo3_selected_index])
+        def render_stage(s, idx):
+            match s:
+                case "Selection":
+                    with gr.Group():
+                        demo3_image_gallery = gr.Gallery(
+                            height=300,
+                            value=first_frame_list,
+                            label="Select an Image to Start Navigation",
+                            columns=[8],
+                            selected_index=idx,
+                        )
+
+                        @demo3_image_gallery.select(
+                            inputs=None, outputs=demo3_selected_index
+                        )
+                        def update_selection(selection: gr.SelectData):
+                            return selection.index
+
+                    demo3_select_button = gr.Button("Select Input Image", variant="primary")
+
+                    @demo3_select_button.click(
+                        inputs=demo3_selected_index,
+                        outputs=[
+                            demo3_stage,
+                            demo3_current_video,
+                            demo3_current_poses,
+                        ],
+                    )
+                    def move_to_generation(idx: int):
+                        if idx is None:
+                            gr.Warning("Image not selected!")
+                            return "Selection", None, None
+                        else:
+                            return (
+                                "Generation",
+                                video_list[idx][:1],
+                                poses_list[idx][:1],
+                            )
+
+                case "Generation":
+                    with gr.Row():
+                        demo3_current_view = gr.Image(
+                            value=first_frame_list[idx],
+                            label="Current View",
+                            width=256,
+                            height=256,
+                        )
+                        demo3_video = gr.Video(
+                            label="Generated Video",
+                            width=256,
+                            height=256,
+                            autoplay=True,
+                            loop=True,
+                            show_share_button=True,
+                            show_download_button=True,
+                        )
+
+                    demo3_generated_gallery = gr.Gallery(
+                        value=[],
+                        label="Generated Frames",
+                        columns=[8],
+                    )
+
+                    with gr.Sidebar():
+                        gr.Markdown(
+                            """
+                            ### Let's Navigate!
+                            **The model will predict the next few frames based on your camera movements. Repeat the process to navigate through the scene.** The most suitable history guidance scheme will be automatically selected based on your camera movements.
+                            """
+                        )
+                        with gr.Tab("Basic", elem_id="basic-controls-tab"):
+                            with gr.Group():
+                                gr.Markdown("_**Select a direction to move:**_")
+                                with gr.Row(elem_id="basic-controls"):
+                                    gr.Button("↰-60°\nTurn", size="sm", min_width=0, variant="primary").click(
+                                        fn=partial(
+                                            navigate_video,
+                                            x_angle=0,
+                                            y_angle=-60,
+                                            distance=0,
+                                        ),
+                                        inputs=[demo3_current_video, demo3_current_poses],
+                                        outputs=[
+                                            demo3_current_video,
+                                            demo3_current_poses,
+                                            demo3_current_view,
+                                            demo3_video,
+                                            demo3_generated_gallery,
+                                        ],
+                                    )
+
+                                    gr.Button("↖-30°\nVeer", size="sm", min_width=0, variant="primary").click(
+                                        fn=partial(
+                                            navigate_video,
+                                            x_angle=0,
+                                            y_angle=-30,
+                                            distance=50,
+                                        ),
+                                        inputs=[demo3_current_video, demo3_current_poses],
+                                        outputs=[
+                                            demo3_current_video,
+                                            demo3_current_poses,
+                                            demo3_current_view,
+                                            demo3_video,
+                                            demo3_generated_gallery,
+                                        ],
+                                    )
+
+                                    gr.Button("↑0°\nAhead", size="sm", min_width=0, variant="primary").click(
+                                        fn=partial(
+                                            navigate_video,
+                                            x_angle=0,
+                                            y_angle=0,
+                                            distance=100,
+                                        ),
+                                        inputs=[demo3_current_video, demo3_current_poses],
+                                        outputs=[
+                                            demo3_current_video,
+                                            demo3_current_poses,
+                                            demo3_current_view,
+                                            demo3_video,
+                                            demo3_generated_gallery,
+                                        ],
+                                    )
+                                    gr.Button("↗30°\nVeer", size="sm", min_width=0, variant="primary").click(
+                                        fn=partial(
+                                            navigate_video,
+                                            x_angle=0,
+                                            y_angle=30,
+                                            distance=50,
+                                        ),
+                                        inputs=[demo3_current_video, demo3_current_poses],
+                                        outputs=[
+                                            demo3_current_video,
+                                            demo3_current_poses,
+                                            demo3_current_view,
+                                            demo3_video,
+                                            demo3_generated_gallery,
+                                        ],
+                                    )
+                                    gr.Button("↱\n60° Turn", size="sm", min_width=0, variant="primary").click(
+                                        fn=partial(
+                                            navigate_video,
+                                            x_angle=0,
+                                            y_angle=60,
+                                            distance=0,
+                                        ),
+                                        inputs=[demo3_current_video, demo3_current_poses],
+                                        outputs=[
+                                            demo3_current_video,
+                                            demo3_current_poses,
+                                            demo3_current_view,
+                                            demo3_video,
+                                            demo3_generated_gallery,
+                                        ],
+                                    )
+                        with gr.Tab("Advanced", elem_id="advanced-controls-tab"):
+                            with gr.Group():
+                                gr.Markdown("_**Select angles and distance:**_")
+
+                                demo3_y_angle = gr.Slider(
+                                    minimum=-90,
+                                    maximum=90,
+                                    value=0,
+                                    step=10,
+                                    label="Horizontal Angle",
+                                    interactive=True,
+                                )
+                                demo3_x_angle = gr.Slider(
+                                    minimum=-40,
+                                    maximum=40,
+                                    value=0,
+                                    step=10,
+                                    label="Vertical Angle",
+                                    interactive=True,
+                                )
+                                demo3_distance = gr.Slider(
+                                    minimum=0,
+                                    maximum=200,
+                                    value=100,
+                                    step=10,
+                                    label="Distance",
+                                    interactive=True,
+                                )
+
+                                gr.Button("Generate Next Move", variant="primary").click(
+                                    fn=partial(
+                                        navigate_video,
+                                    ),
+                                    inputs=[demo3_current_video, demo3_current_poses, demo3_x_angle, demo3_y_angle, demo3_distance],
+                                    outputs=[
+                                        demo3_current_video,
+                                        demo3_current_poses,
+                                        demo3_current_view,
+                                        demo3_video,
+                                        demo3_generated_gallery,
+                                    ],
+                                )
+                        with gr.Group():
+                            gr.Markdown("_You can always undo your last move:_")
+                            gr.Button("Undo Last Move", variant="huggingface").click(
+                                fn=undo_navigation,
+                                inputs=[demo3_current_video, demo3_current_poses],
+                                outputs=[
+                                    demo3_current_video,
+                                    demo3_current_poses,
+                                    demo3_current_view,
+                                    demo3_video,
+                                    demo3_generated_gallery,
+                                ],
                             )
+                        with gr.Group():
+                            gr.Markdown("_At the end, apply temporal super-resolution to obtain a smoother video:_")
+                            demo3_interpolation_factor = gr.Slider(
                                 minimum=2,
                                 maximum=10,
-                                value=
+                                value=2,
                                 step=1,
-                                label="
-                                info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
+                                label="Interpolation Factor",
                                 interactive=True,
                             )
-                            gr.Button("
-                            fn=
-                            inputs=[
+                            gr.Button("Smooth Out Video", variant="huggingface").click(
+                                fn=smooth_navigation,
+                                inputs=[demo3_current_video, demo3_current_poses, demo3_interpolation_factor],
+                                outputs=[
+                                    demo3_current_video,
+                                    demo3_current_poses,
+                                    demo3_current_view,
+                                    demo3_video,
+                                    demo3_generated_gallery,
                                 ],
-                            outputs=demo2_video,
                             )
 
-        with gr.Tab("Single Image → Extremely Long Video", id="task-3"):
-            gr.Markdown(
-                """
-                ## Demo 3: Single Image → Extremely Long Video
-                > #### _TODO._
-                """
-            )
 
 if __name__ == "__main__":
     demo.launch()
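The task-3 tab added above is a stateful loop around `navigate_video`. A hypothetical headless driver, assuming the globals defined earlier in app.py (`dfot`, `video_list`, `poses_list`) and a CUDA device, would look roughly like:

```python
video, poses = video_list[0][:1], poses_list[0][:1]  # start from a single frame
for y_angle in (0, 30, 30, 0):                       # a short camera path
    video, poses, last_view, clip_path, frames = navigate_video(
        video, poses, x_angle=0, y_angle=y_angle, distance=100
    )
# Optional temporal super-resolution at the end, as the "Smooth Out Video"
# button does in the UI:
video, poses, *rest = smooth_navigation(video, poses, interpolation_factor=2)
```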
camera_pose.py ADDED

@@ -0,0 +1,94 @@
+import torch
+from utils.geometry_utils import CameraPose
+from einops import rearrange, repeat
+import math
+import roma
+
+class ControllableCameraPose(CameraPose):
+    def to_vectors(self) -> torch.Tensor:
+        """
+        Returns the raw camera poses.
+        Returns:
+            torch.Tensor: The raw camera poses. Shape (B, T, 4 + 12).
+        """
+        RT = torch.cat([self._R, rearrange(self._T, "b t i -> b t i 1")], dim=-1)
+        return torch.cat([self._K, rearrange(RT, "b t i j -> b t (i j)")], dim=-1)
+
+    def extend(
+        self,
+        num_frames: int,
+        x_angle: float = 0.0,
+        y_angle: float = 0.0,
+        distance: float = 100.0,
+    ) -> None:
+        """
+        Extends the camera poses.
+        Taking 0 degrees as the direction of the last camera pose,
+        smoothly move & rotate the camera poses in the direction of the
+        given angle (clockwise) in a 2D plane.
+        Args:
+            num_frames (int): The number of frames to extend.
+            x_angle (float): The vertical angle to move toward, in degrees.
+            y_angle (float): The horizontal angle to move toward, in degrees.
+        """
+        MOVING_SCALE = 0.5 * distance / 100
+        self._normalize_by(self._R[:, -1], self._T[:, -1])
+
+        # first compute relative poses for the final n + num_frames th frame
+
+        # compute the rotation matrix for the given angle
+        R_final = roma.euler_to_rotmat(
+            convention="xyz",
+            angles=torch.tensor(
+                [-x_angle, -y_angle, 0], device=self._R.device, dtype=torch.float32
+            ),
+            degrees=True,
+            dtype=torch.float32,
+            device=self._R.device,
+        ).unsqueeze(0)
+
+        # compute the translation vector for the given angle
+        T_final = torch.tensor(
+            [
+                -MOVING_SCALE * num_frames * math.sin(math.radians(y_angle)),
+                MOVING_SCALE * num_frames * math.sin(math.radians(x_angle)),
+                -MOVING_SCALE * num_frames * math.cos(math.radians(y_angle)),
+            ],
+            device=self._T.device,
+            dtype=self._T.dtype,
+        ).unsqueeze(0)
+
+        R = torch.cat(
+            [self._R, repeat(R_final, "b i j -> b t i j", t=num_frames).clone()], dim=1
+        )
+        T = torch.cat(
+            [self._T, repeat(T_final, "b i -> b t i", t=num_frames).clone()], dim=1
+        )
+        K = torch.cat(
+            [self._K, repeat(self._K[:, -1], "b i -> b t i", t=num_frames).clone()],
+            dim=1,
+        )
+        self._R = R
+        self._T = T
+        self._K = K
+        # interpolate all frames between the last original frame and the final frame
+        self.replace_with_interpolation(
+            torch.cat(
+                [
+                    torch.zeros_like(self._T[:, :-num_frames, 0]),
+                    torch.ones_like(self._T[:, -num_frames:-1, 0]),
+                    torch.zeros_like(self._T[:, -1:, 0]),
+                ],
+                dim=-1,
+            ).bool()
+        )
+
+def extend_poses(
+    conditions: torch.Tensor,
+    n: int,
+    x_angle: float = 0.0,
+    y_angle: float = 0.0,
+    distance: float = 0.0,
+) -> torch.Tensor:
+    poses = ControllableCameraPose.from_vectors(conditions)
+    poses.extend(n, x_angle, y_angle, distance)
+    return poses.to_vectors()
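The displacement computed in `extend` is planar: `y_angle` steers within the x-z plane, `x_angle` tilts into y, and the per-frame step is scaled by `distance`. A worked example of the arithmetic (illustrative values; the convention that -z is forward and -x is left is inferred from the code, not documented):

```python
import math

num_frames, distance, y_angle, x_angle = 4, 100, 30, 0
MOVING_SCALE = 0.5 * distance / 100                                # 0.5 per frame
dx = -MOVING_SCALE * num_frames * math.sin(math.radians(y_angle))  # -1.0
dy = MOVING_SCALE * num_frames * math.sin(math.radians(x_angle))   # 0.0
dz = -MOVING_SCALE * num_frames * math.cos(math.radians(y_angle))  # ~-1.73
# Total in-plane displacement is MOVING_SCALE * num_frames = 2.0 units:
# a 30-degree veer over 4 frames, mostly forward, slightly sideways.
print(dx, dy, dz, math.hypot(dx, dz))
```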
history_guidance.py ADDED

@@ -0,0 +1,24 @@
+from algorithms.dfot.history_guidance import HistoryGuidance as _HistoryGuidance
+
+class HistoryGuidance(_HistoryGuidance):
+    @classmethod
+    def smart(
+        cls,
+        x_angle: float,
+        y_angle: float,
+        distance: float,
+        visualize: bool = False,
+    ):
+        if abs(x_angle) < 30 and abs(y_angle) < 30 and distance < 150:
+            return cls.stabilized_fractional(
+                guidance_scale=4.0,
+                stabilization_level=0.02,
+                freq_scale=0.4,
+                visualize=visualize,
+            )
+        else:
+            return cls.stabilized_vanilla(
+                guidance_scale=4.0,
+                stabilization_level=0.02,
+                visualize=visualize,
+            )
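The selector above branches on the magnitude of the requested move: gentle motions (|x_angle| < 30, |y_angle| < 30, distance < 150) get the fractional scheme, sharper ones the vanilla scheme. Both constructors come from `algorithms.dfot.history_guidance`; their exact behavior is assumed here from their names rather than shown. A usage sketch:

```python
from history_guidance import HistoryGuidance

hg = HistoryGuidance.smart(x_angle=0, y_angle=20, distance=100)
# gentle move -> stabilized_fractional (all three thresholds satisfied)

hg = HistoryGuidance.smart(x_angle=0, y_angle=60, distance=0)
# sharp 60-degree turn -> stabilized_vanilla (|y_angle| >= 30)
```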