update
- .gitmodules +3 -0
- app.py +153 -70
- demo.py +88 -21
- models/pipelines.py +327 -122
.gitmodules
CHANGED
@@ -1,3 +1,6 @@
 [submodule "submodules/MoGe"]
 	path = submodules/MoGe
 	url = https://github.com/microsoft/MoGe.git
+[submodule "submodules/vggt"]
+	path = submodules/vggt
+	url = https://github.com/facebookresearch/vggt.git
app.py
CHANGED
@@ -16,6 +16,7 @@ sys.path.append(project_root)
 try:
     sys.path.append(os.path.join(project_root, "submodules/MoGe"))
+    sys.path.append(os.path.join(project_root, "submodules/vggt"))
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 except:
     print("Warning: MoGe not found, motion transfer will not be applied")
@@ -27,6 +28,8 @@ hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir="checkpoints/")
 
 from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
 from submodules.MoGe.moge.model import MoGeModel
+from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from submodules.vggt.vggt.models.vggt import VGGT
 
 # Parse command line arguments
 parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
@@ -47,6 +50,7 @@ os.makedirs("outputs", exist_ok=True)
 # Create project tmp directory instead of using system temp
 os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
 os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
+
 def load_media(media_path, max_frames=49, transform=None):
     """Load video or image frames and convert to tensor
@@ -69,22 +73,52 @@ def load_media(media_path, max_frames=49, transform=None):
     is_video = ext in ['.mp4', '.avi', '.mov']
 
     if is_video:
-        …
-        …
+        # Load video file info
+        video_clip = VideoFileClip(media_path)
+        duration = video_clip.duration
+        original_fps = video_clip.fps
+
+        # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
+        if duration > 6.0:
+            sampling_fps = 8  # 8 frames per second
+            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
+            fps = sampling_fps
+        # Cases 2 and 3: Video shorter than 6 seconds
+        else:
+            # Load all frames
+            frames = load_video(media_path)
+
+            # Case 2: Total frames less than max_frames, need interpolation
+            if len(frames) < max_frames:
+                fps = len(frames) / duration  # Keep original fps
+
+                # Evenly interpolate to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+            # Case 3: Total frames more than max_frames but video less than 6 seconds
+            else:
+                # Evenly sample to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+                fps = max_frames / duration  # New fps to maintain duration
     else:
         # Handle image as single frame
         image = load_image(media_path)
         frames = [image]
         fps = 8  # Default fps for images
-
-        …
-    if len(frames) > max_frames:
-        frames = frames[:max_frames]
-    elif len(frames) < max_frames:
-        last_frame = frames[-1]
-        while len(frames) < max_frames:
-            frames.append(last_frame.copy())
-
+
+        # Duplicate frame to max_frames
+        while len(frames) < max_frames:
+            frames.append(frames[0].copy())
+
     # Convert frames to tensor
     video_tensor = torch.stack([transform(frame) for frame in frames])
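Note: both resampling branches above reduce to the same even index selection over np.linspace; a standalone sketch of that step (NumPy only, hypothetical helper name):

import numpy as np

def resample_frames(frames, target_len):
    # Evenly select target_len frames by index: indices repeat when
    # upsampling (duplicated frames) and skip when downsampling.
    indices = np.linspace(0, len(frames) - 1, target_len)
    return [frames[int(i)] for i in indices]

# e.g. 30 input frames -> 49 output frames (some frames repeat),
# or 120 input frames -> 49 output frames (roughly every 2.4th frame kept)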
@@ -131,6 +165,7 @@ def save_uploaded_file(file):
 
 das_pipeline = None
 moge_model = None
+vggt_model = None
 
 @spaces.GPU
 def get_das_pipeline():
@@ -147,6 +182,13 @@ def get_moge_model():
         moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
     return moge_model
 
+@spaces.GPU
+def get_vggt_model():
+    global vggt_model
+    if vggt_model is None:
+        das = get_das_pipeline()
+        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
+    return vggt_model
 
 def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
     """Process video motion transfer task"""
@@ -154,19 +196,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None
+            return None, None
 
         print(f"DEBUG: Repaint option: {mt_repaint_option}")
         print(f"DEBUG: Repaint image: {mt_repaint_image}")
-
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
         if not is_video:
             tracking_method = "moge"
             print("Image input detected, using MoGe for tracking video generation.")
         else:
-            tracking_method = "spatracker"
+            tracking_method = "cotracker"
 
         repaint_img_tensor = None
         if mt_repaint_image is not None:
@@ -180,7 +223,9 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
                 prompt=prompt,
                 depth_path=None
             )
+
         tracking_tensor = None
+        tracking_path = None
         if tracking_method == "moge":
             moge = get_moge_model()
             infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
@@ -195,32 +240,31 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
 
             pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
 
-            _, tracking_tensor = das.visualize_tracking_moge(
+            tracking_path, tracking_tensor = das.visualize_tracking_moge(
                 pred_tracks.cpu().numpy(),
                 infer_result["mask"].cpu().numpy()
             )
             print('Export tracking video via MoGe')
         else:
-            …
-            …
-            …
-            print('Export tracking video via SpaTracker')
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
+            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
+            print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=8,
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )
 
-        return output_path
+        return tracking_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_camera_control(source, prompt, camera_motion, tracking_method):
     """Process camera control task"""
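Note: generate_tracking_cotracker and visualize_tracking_cotracker are new methods in models/pipelines.py, and the new CoTracker section there is cut off at the end of this page. For orientation only — the hub entrypoint, grid_size, and the absence of the pipeline's depth channel are assumptions, not this repo's code — a minimal CoTracker call looks like:

import torch

def track_with_cotracker(video_tensor):
    # video_tensor: [T, C, H, W] in [0, 1]; CoTracker expects [B, T, C, H, W] in 0-255.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
    video = video_tensor.unsqueeze(0).to(device) * 255.0
    pred_tracks, pred_visibility = cotracker(video, grid_size=30)  # [1, T, N, 2], [1, T, N]
    return pred_tracks.squeeze(0), pred_visibility.squeeze(0)

The pipeline's tracks are [T, N, 3] (with a depth component used for reprojection), so the real method presumably augments these 2-D tracks with per-point depth.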
@@ -228,17 +272,18 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         # Save uploaded files
         input_media_path = save_uploaded_file(source)
         if input_media_path is None:
-            return None
+            return None, None
 
         print(f"DEBUG: Camera motion: '{camera_motion}'")
         print(f"DEBUG: Tracking method: '{tracking_method}'")
 
         das = get_das_pipeline()
-
         video_tensor, fps, is_video = load_media(input_media_path)
-        if not is_video:
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
+        if not is_video:
             tracking_method = "moge"
-            print("Image input detected, using MoGe for tracking video generation.")
+            print("Image input detected, switching to MoGe")
 
         cam_motion = CameraMotionGenerator(camera_motion)
         repaint_img_tensor = None
@@ -267,32 +312,54 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
             )
             print('Export tracking video via MoGe')
         else:
-            …
-            pred_tracks, pred_visibility = …
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
+
+            t, c, h, w = video_tensor.shape
+            new_width = 518
+            new_height = round(h * (new_width / w) / 14) * 14
+            resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+            video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
+
+            if new_height > 518:
+                start_y = (new_height - 518) // 2
+                video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+            vggt_model = get_vggt_model()
+
+            with torch.no_grad():
+                with torch.cuda.amp.autocast(dtype=das.dtype):
+                    video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+                    aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+
+            cam_motion.set_intr(intr)
+            cam_motion.set_extr(extr)
+
             if camera_motion:
                 poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
-                …
+                pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
+                pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
                 print("Camera motion applied")
-
-            …
-            print('Export tracking video via SpaTracker')
-
+
+            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, None)
+            print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=8,
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
     """Process object manipulation task"""
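Note: the resize block above matches VGGT's expected input — width 518 px and height a multiple of the 14-px patch size, center-cropped if the result is taller than 518. For the app's 480x720 frames the arithmetic works out as follows (a worked example, not new code in the commit):

# VGGT preprocessing arithmetic for a 480x720 frame
h, w = 480, 720
new_width = 518
new_height = round(h * (new_width / w) / 14) * 14  # 480*518/720 = 345.3 -> 24.67 patches -> 25 -> 350
assert new_height % 14 == 0                        # 350 is patch-aligned
crop_needed = new_height > 518                     # False here, so no center crop
print(new_height, crop_needed)                     # 350 False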
@@ -300,21 +367,21 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
         # Save uploaded files
         input_image_path = save_uploaded_file(source)
         if input_image_path is None:
-            return None
+            return None, None
 
         object_mask_path = save_uploaded_file(object_mask)
         if object_mask_path is None:
             print("Object mask not provided")
-            return None
-
+            return None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_image_path)
-        if not is_video:
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
+        if not is_video:
             tracking_method = "moge"
-            print("Image input detected, using MoGe for tracking video generation.")
+            print("Image input detected, switching to MoGe")
 
-
         mask_image = Image.open(object_mask_path).convert('L')
         mask_image = transforms.Resize((480, 720))(mask_image)
         mask = torch.from_numpy(np.array(mask_image) > 127)
@@ -322,10 +389,10 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
         motion_generator = ObjectMotionGenerator(device=das.device)
         repaint_img_tensor = None
         tracking_tensor = None
+
         if tracking_method == "moge":
             moge = get_moge_model()
 
-
             infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
             H, W = infer_result["points"].shape[0:2]
             pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
@@ -342,7 +409,6 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
             poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
             pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
 
-
             cam_motion = CameraMotionGenerator(None)
             cam_motion.set_intr(infer_result["intrinsics"])
             pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
@@ -353,9 +419,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
             )
             print('Export tracking video via MoGe')
         else:
-            …
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
 
+            t, c, h, w = video_tensor.shape
+            new_width = 518
+            new_height = round(h * (new_width / w) / 14) * 14
+            resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+            video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
 
+            if new_height > 518:
+                start_y = (new_height - 518) // 2
+                video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+            vggt_model = get_vggt_model()
+
+            with torch.no_grad():
+                with torch.cuda.amp.autocast(dtype=das.dtype):
+                    video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+                    aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
 
         pred_tracks = motion_generator.apply_motion(
             pred_tracks=pred_tracks.squeeze(),
@@ -363,30 +447,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
             motion_type=object_motion,
             distance=50,
             num_frames=49,
-            tracking_method="spatracker"
-        )
+            tracking_method="cotracker"
+        )
         print(f"Object motion '{object_motion}' applied using provided mask")
 
-        …
-        …
-        print('Export tracking video via SpaTracker')
-
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), None)
+        print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=8,
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
     """Process mesh animation task"""
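Note: ObjectMotionGenerator.apply_motion itself is not shown in this diff. The sketch below is only a guess at the general shape of such a method — a masked, per-frame screen-space displacement — and is not the repository's implementation; every name and detail here is an assumption:

import torch

def apply_motion(pred_tracks, mask, motion_type="up", distance=50, num_frames=49):
    # pred_tracks: [T, N, 3] screen-space tracks; mask: [H, W] bool for frame 0.
    directions = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}
    dx, dy = directions[motion_type]
    tracks = pred_tracks.clone()
    # Points whose first-frame location falls inside the mask get displaced.
    xy0 = tracks[0, :, :2].long()
    selected = mask[xy0[:, 1].clamp(0, mask.shape[0] - 1),
                    xy0[:, 0].clamp(0, mask.shape[1] - 1)]
    for t in range(num_frames):
        offset = distance * t / (num_frames - 1)  # linear ramp up to `distance` px
        tracks[t, selected, 0] += dx * offset
        tracks[t, selected, 1] += dy * offset
    return tracks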
@@ -394,15 +475,16 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None
+            return None, None
 
         tracking_video_path = save_uploaded_file(tracking_video)
         if tracking_video_path is None:
-            return None
-
+            return None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
         tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
         repaint_img_tensor = None
         if ma_repaint_image is not None:
@@ -420,18 +502,18 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=8,
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_video_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
+        return None, None
 
 # Create Gradio interface with updated layout
 with gr.Blocks(title="Diffusion as Shader") as demo:
@@ -444,6 +526,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
 
     with right_column:
         output_video = gr.Video(label="Generated Video")
+        tracking_video = gr.Video(label="Tracking Video")
 
     with left_column:
        source = gr.File(label="Source", file_types=["image", "video"])
@@ -479,7 +562,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                source, common_prompt,
                mt_repaint_option, mt_repaint_image
            ],
-            outputs=[output_video]
+            outputs=[tracking_video, output_video]
        )
 
        # Camera Control tab
@@ -597,8 +680,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
 
            cc_tracking_method = gr.Radio(
                label="Tracking Method",
-                choices=["spatracker", "moge"],
-                value="spatracker"
+                choices=["moge", "cotracker"],
+                value="cotracker"
            )
 
            # Add run button for Camera Control tab
@@ -611,7 +694,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                source, common_prompt,
                cc_camera_motion, cc_tracking_method
            ],
-            outputs=[output_video]
+            outputs=[tracking_video, output_video]
        )
 
        # Object Manipulation tab
@@ -629,8 +712,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
            )
            om_tracking_method = gr.Radio(
                label="Tracking Method",
-                choices=["spatracker", "moge"],
-                value="spatracker"
+                choices=["moge", "cotracker"],
+                value="cotracker"
            )
 
            # Add run button for Object Manipulation tab
@@ -643,7 +726,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                source, common_prompt,
                om_object_motion, om_object_mask, om_tracking_method
            ],
-            outputs=[output_video]
+            outputs=[tracking_video, output_video]
        )
 
        # Animating meshes to video tab
@@ -683,7 +766,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                source, common_prompt,
                ma_tracking_video, ma_repaint_option, ma_repaint_image
            ],
-            outputs=[output_video]
+            outputs=[tracking_video, output_video]
        )
 
    # Launch interface
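Note: every process_* handler now returns a (tracking_path, output_path) pair, so each run button binds two outputs and Gradio maps the returned tuple onto them in order. The general pattern, with a placeholder fn (not this app's code):

import gradio as gr

def run(source):
    tracking_path, output_path = "tracking.mp4", "result.mp4"  # placeholder paths
    return tracking_path, output_path  # tuple order matches outputs=[...]

with gr.Blocks() as demo:
    src = gr.File(label="Source")
    btn = gr.Button("Run")
    tracking = gr.Video(label="Tracking Video")
    result = gr.Video(label="Generated Video")
    btn.click(fn=run, inputs=[src], outputs=[tracking, result])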
demo.py
CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
 project_root = os.path.dirname(os.path.abspath(__file__))
 try:
     sys.path.append(os.path.join(project_root, "submodules/MoGe"))
+    sys.path.append(os.path.join(project_root, "submodules/vggt"))
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 except:
     print("Warning: MoGe not found, motion transfer will not be applied")
@@ -18,6 +19,8 @@ from diffusers.utils import load_image, load_video
 
 from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
 from submodules.MoGe.moge.model import MoGeModel
+from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from submodules.vggt.vggt.models.vggt import VGGT
 
 def load_media(media_path, max_frames=49, transform=None):
     """Load video or image frames and convert to tensor
@@ -28,7 +31,7 @@ def load_media(media_path, max_frames=49, transform=None):
         transform (callable): Transform to apply to frames
 
     Returns:
-        Tuple[torch.Tensor, float]: Video tensor [T,C,H,W] and …
+        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and is_video flag
     """
     if transform is None:
         transform = transforms.Compose([
@@ -41,22 +44,52 @@ def load_media(media_path, max_frames=49, transform=None):
     is_video = ext in ['.mp4', '.avi', '.mov']
 
     if is_video:
-        …
-        …
+        # Load video file info
+        video_clip = VideoFileClip(media_path)
+        duration = video_clip.duration
+        original_fps = video_clip.fps
+
+        # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
+        if duration > 6.0:
+            sampling_fps = 8  # 8 frames per second
+            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
+            fps = sampling_fps
+        # Cases 2 and 3: Video shorter than 6 seconds
+        else:
+            # Load all frames
+            frames = load_video(media_path)
+
+            # Case 2: Total frames less than max_frames, need interpolation
+            if len(frames) < max_frames:
+                fps = len(frames) / duration  # Keep original fps
+
+                # Evenly interpolate to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+            # Case 3: Total frames more than max_frames but video less than 6 seconds
+            else:
+                # Evenly sample to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+                fps = max_frames / duration  # New fps to maintain duration
     else:
         # Handle image as single frame
         image = load_image(media_path)
         frames = [image]
         fps = 8  # Default fps for images
-
-        …
-    if len(frames) > max_frames:
-        frames = frames[:max_frames]
-    elif len(frames) < max_frames:
-        last_frame = frames[-1]
-        while len(frames) < max_frames:
-            frames.append(last_frame.copy())
-
+
+        # Duplicate frame to max_frames
+        while len(frames) < max_frames:
+            frames.append(frames[0].copy())
+
     # Convert frames to tensor
     video_tensor = torch.stack([transform(frame) for frame in frames])
@@ -77,8 +110,8 @@ if __name__ == "__main__":
                         help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
     parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
     parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
-    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'],
-                        help='Tracking method to use (spatracker or moge)')
+    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge', 'cotracker'],
+                        help='Tracking method to use (spatracker, cotracker or moge)')
     args = parser.parse_args()
 
     # Load input video/image
@@ -89,6 +122,7 @@ if __name__ == "__main__":
 
     # Initialize pipeline
     das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
+    das.fps = fps
     if args.tracking_method == "moge" and args.tracking_path is None:
         moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
 
@@ -153,7 +187,7 @@ if __name__ == "__main__":
         poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
         # change pred_tracks into screen coordinate
         pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
-        pred_tracks = cam_motion.…
+        pred_tracks = cam_motion.w2s_moge(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
         _, tracking_tensor = das.visualize_tracking_moge(
             pred_tracks.cpu().numpy(),
             infer_result["mask"].cpu().numpy()
@@ -161,13 +195,44 @@ if __name__ == "__main__":
         print('export tracking video via MoGe.')
 
     else:
-        …
-        …
+        if args.tracking_method == "cotracker":
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)  # T N 3, T N
+        else:
+            pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)  # T N 3, T N, B N
+
+        # Preprocess video tensor to match VGGT requirements
+        t, c, h, w = video_tensor.shape
+        new_width = 518
+        new_height = round(h * (new_width / w) / 14) * 14
+        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
+
+        if new_height > 518:
+            start_y = (new_height - 518) // 2
+            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+        # Get extrinsic and intrinsic matrices
+        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(dtype=das.dtype):
+
+                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+                extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+                depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_vggt, ps_idx)
+
+        cam_motion.set_intr(intr)
+        cam_motion.set_extr(extr)
 
         # Apply camera motion if specified
         if args.camera_motion:
             poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
-            …
+            pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
+            pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
             print("Camera motion applied")
 
         # Apply object motion if specified
@@ -184,7 +249,7 @@ if __name__ == "__main__":
             motion_generator = ObjectMotionGenerator(device=das.device)
 
             pred_tracks = motion_generator.apply_motion(
-                pred_tracks=pred_tracks
+                pred_tracks=pred_tracks,
                 mask=mask,
                 motion_type=args.object_motion,
                 distance=50,
@@ -193,12 +258,14 @@ if __name__ == "__main__":
             ).unsqueeze(0)
             print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
 
-        …
-        …
+        if args.tracking_method == "cotracker":
+            _, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
+        else:
+            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
 
     das.apply_tracking(
         video_tensor=video_tensor,
-        fps=8,
+        fps=fps,
         tracking_tensor=tracking_tensor,
         img_cond_tensor=repaint_img_tensor,
         prompt=args.prompt,
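Note: s2w_vggt and w2s_vggt are added to CameraMotionGenerator in models/pipelines.py, but their bodies fall in the truncated part of this page. Conceptually they lift screen-space tracks (u, v, depth) to world space using the VGGT camera and re-project them under the new pose sequence. A hedged sketch of that round trip — the function names, the [T, 3, 4] OpenCV-style extrinsics, and the depth-in-third-channel convention are assumptions:

import torch

def screen_to_world(uvd, extr, intr):
    # uvd: [T, N, 3] pixel coords + depth; extr: [T, 3, 4] world->camera; intr: [T, 3, 3]
    T, N, _ = uvd.shape
    uv1 = torch.cat([uvd[..., :2], torch.ones(T, N, 1)], dim=-1)        # homogeneous pixels
    pts_cam = torch.bmm(uv1, torch.linalg.inv(intr).transpose(1, 2))    # back-projected rays
    pts_cam = pts_cam * uvd[..., 2:3]                                   # scale rays by depth
    R, t = extr[:, :, :3], extr[:, :, 3]
    return torch.bmm(pts_cam - t.unsqueeze(1), R)                       # X_w = R^T (X_c - t)

def world_to_screen(pts_world, extr, intr):
    R, t = extr[:, :, :3], extr[:, :, 3]
    pts_cam = torch.bmm(pts_world, R.transpose(1, 2)) + t.unsqueeze(1)  # X_c = R X_w + t
    proj = torch.bmm(pts_cam, intr.transpose(1, 2))
    uv = proj[..., :2] / proj[..., 2:3]                                 # perspective divide
    return torch.cat([uv, pts_cam[..., 2:3]], dim=-1)                   # keep camera-space depth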
models/pipelines.py
CHANGED
@@ -22,9 +22,9 @@ from models.spatracker.utils.visualizer import Visualizer
 from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
 
 from submodules.MoGe.moge.model import MoGeModel
+
 from image_gen_aux import DepthPreprocessor
 from moviepy.editor import ImageSequenceClip
-import spaces
 
 class DiffusionAsShaderPipeline:
     def __init__(self, gpu_id=0, output_dir='outputs'):
@@ -45,6 +45,7 @@ class DiffusionAsShaderPipeline:
         # device
         self.device = f"cuda:{gpu_id}"
         torch.cuda.set_device(gpu_id)
+        self.dtype = torch.bfloat16
 
         # files
         self.output_dir = output_dir
@@ -56,7 +57,6 @@ class DiffusionAsShaderPipeline:
             transforms.ToTensor()
         ])
 
-    @spaces.GPU(duration=240)
     @torch.no_grad()
     def _infer(
         self,
@@ -65,7 +65,7 @@ class DiffusionAsShaderPipeline:
         tracking_tensor: torch.Tensor = None,
         image_tensor: torch.Tensor = None,  # [C,H,W] in range [0,1]
         output_path: str = "./output.mp4",
-        num_inference_steps: int = …,
+        num_inference_steps: int = 25,
         guidance_scale: float = 6.0,
         num_videos_per_prompt: int = 1,
         dtype: torch.dtype = torch.bfloat16,
@@ -114,6 +114,8 @@ class DiffusionAsShaderPipeline:
         pipe.text_encoder.eval()
         pipe.vae.eval()
 
+        self.dtype = dtype
+
         # Process tracking tensor
         tracking_maps = tracking_tensor.float()  # [T, C, H, W]
         tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
@@ -167,60 +169,9 @@ class DiffusionAsShaderPipeline:
 
     def _set_camera_motion(self, camera_motion):
         self.camera_motion = camera_motion
-
-    def _get_intr(self, fov, H=480, W=720):
-        fov_rad = math.radians(fov)
-        focal_length = (W / 2) / math.tan(fov_rad / 2)
-
-        cx = W / 2
-        cy = H / 2
-
-        intr = torch.tensor([
-            [focal_length, 0, cx],
-            [0, focal_length, cy],
-            [0, 0, 1]
-        ], dtype=torch.float32)
-
-        return intr
-
-    @spaces.GPU
-    def _apply_poses(self, pts, intr, poses):
-        """
-        Args:
-            pts (torch.Tensor): pointclouds coordinates [T, N, 3]
-            intr (torch.Tensor): camera intrinsics [T, 3, 3]
-            poses (numpy.ndarray): camera poses [T, 4, 4]
-        """
-        poses = torch.from_numpy(poses).float().to(self.device)
-
-        T, N, _ = pts.shape
-        ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
-        pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1)  # (T, N, 3)
-        pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2))  # (T, N, 3)
-        pts_cam[:, :, :3] /= pts[:, :, 2:3]
-
-        # to homogeneous
-        pts_cam = torch.cat([pts_cam, ones], dim=-1)  # (T, N, 4)
-
-        if poses.shape[0] == 1:
-            poses = poses.repeat(T, 1, 1)
-        elif poses.shape[0] != T:
-            raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
-
-        pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3]  # (T, N, 3)
-
-        pts_proj = torch.bmm(pts_world, intr.transpose(1, 2))  # (T, N, 3)
-        pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
-
-        return pts_proj
-
-    def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
-        intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
-        tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
-        return tracking_pts
 
     ##============= SpatialTracker =============##
 
     def generate_tracking_spatracker(self, video_tensor, density=70):
         """Generate tracking video
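Note: for reference, the deleted _apply_poses implemented a screen -> camera -> world -> screen chain, which the VGGT-based helpers now replace. In standard form, for a pixel (u, v) with depth d, intrinsics K, and a 4x4 pose P (the deleted code differed in some details):

\begin{aligned}
\mathbf{x}_{\text{cam}} &= d \, K^{-1} \begin{pmatrix} u \\ v \\ 1 \end{pmatrix}, \\
\mathbf{x}_{\text{world}} &= P \begin{pmatrix} \mathbf{x}_{\text{cam}} \\ 1 \end{pmatrix}, \\
\begin{pmatrix} u' \\ v' \\ w \end{pmatrix} &= K \, \mathbf{x}_{\text{world}},
\qquad (u'', v'') = \left(\tfrac{u'}{w}, \tfrac{v'}{w}\right).
\end{aligned}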
@@ -233,7 +184,7 @@ class DiffusionAsShaderPipeline:
         print("Loading tracking models...")
         # Load tracking model
         tracker = SpaTrackerPredictor(
-            checkpoint=os.path.join(project_root, 'checkpoints/spatracker/spaT_final.pth'),
+            checkpoint=os.path.join(project_root, 'checkpoints/spaT_final.pth'),
             interp_shape=(384, 576),
             seq_length=12
         ).to(self.device)
@@ -268,14 +219,13 @@ class DiffusionAsShaderPipeline:
                 progressive_tracking=False
             )
 
-            return pred_tracks, pred_visibility, T_Firsts
+            return pred_tracks.squeeze(0), pred_visibility.squeeze(0), T_Firsts
 
         finally:
             # Clean up GPU memory
             del tracker, self.depth_preprocessor
             torch.cuda.empty_cache()
 
-    @spaces.GPU
     def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
         video = video.unsqueeze(0).to(self.device)
         vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
@@ -365,7 +315,6 @@ class DiffusionAsShaderPipeline:
             outline=tuple(color),
         )
 
-    @spaces.GPU
     def visualize_tracking_moge(self, points, mask, save_tracking=True):
         """Visualize tracking results from MoGe model
 
@@ -399,8 +348,6 @@ class DiffusionAsShaderPipeline:
         normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
         colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
         colors = colors.astype(np.uint8)
-        # colors = colors * mask[..., None]
-        # points = points * mask[None, :, :, None]
 
         points = points.reshape(T, -1, 3)
         colors = colors.reshape(-1, 3)
@@ -408,7 +355,7 @@ class DiffusionAsShaderPipeline:
         # Initialize list to store frames
         frames = []
 
-        for i, pts_i in enumerate(tqdm(points)):
+        for i, pts_i in enumerate(tqdm(points, desc="rendering frames")):
            pixels, depths = pts_i[..., :2], pts_i[..., 2]
            pixels[..., 0] = pixels[..., 0] * W
            pixels[..., 1] = pixels[..., 1] * H
@@ -451,8 +398,178 @@ class DiffusionAsShaderPipeline:
             tracking_path = None
 
         return tracking_path, tracking_video
+
+    ##============= CoTracker =============##
+
+    def generate_tracking_cotracker(self, video_tensor, density=70):
+        """Generate tracking video
+
+        Args:
+            video_tensor (torch.Tensor): Input video tensor
+        … (the remainder of the new ~170-line CoTracker section, including visualize_tracking_cotracker, is cut off in this page view)
 
-    @spaces.GPU(duration=240)
     def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
         """Generate final video with motion transfer
@@ -478,7 +595,7 @@ class DiffusionAsShaderPipeline:
             tracking_tensor=tracking_tensor,
             image_tensor=img_cond_tensor,
             output_path=final_output,
-            num_inference_steps=…,
+            num_inference_steps=25,
             guidance_scale=6.0,
             dtype=torch.bfloat16,
             fps=self.fps
@@ -493,7 +610,6 @@ class DiffusionAsShaderPipeline:
         """
         self.object_motion = motion_type
 
-@spaces.GPU(duration=120)
 class FirstFrameRepainter:
     def __init__(self, gpu_id=0, output_dir='outputs'):
         """Initialize FirstFrameRepainter
@@ -506,8 +622,7 @@ class FirstFrameRepainter:
         self.output_dir = output_dir
         self.max_depth = 65.0
         os.makedirs(output_dir, exist_ok=True)
-
-    @spaces.GPU(duration=120)
+
     def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
         """Repaint first frame using Flux
@@ -599,48 +714,158 @@ class CameraMotionGenerator:
         fx = fy = (W / 2) / math.tan(fov_rad / 2)
 
         self.intr[0, 0] = fx
         self.intr[1, 1] = fy
 
-    def …(self, …):
         """
         Args:
-            …
         """
-        if isinstance(…):
-            …
-            raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
-        …
+    … (new VGGT-space helpers — e.g. the s2w_vggt / w2s_vggt used by app.py and demo.py above — are added in this range, but their bodies are not visible in this page view)
 
-    def w2s(self, pts, poses):
+    def w2s_moge(self, pts, poses):
         if isinstance(poses, np.ndarray):
             poses = torch.from_numpy(poses)
         assert poses.shape[0] == self.frame_num
         poses = poses.to(torch.float32).to(self.device)
         T, N, _ = pts.shape  # (T, N, 3)
         intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
-        # Step 1: expand points to (T, N, 4), padding the last dim with 1 (homogeneous coordinates)
         ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
         points_world_h = torch.cat([pts, ones], dim=-1)
         points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
@@ -649,22 +874,21 @@ class CameraMotionGenerator:
         points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
 
         uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
-
-        # Step 5: extract depth (Z) and concatenate
         depth = points_camera[:, :, 2:3]  # (T, N, 1)
         uvd = torch.cat([uv, depth], dim=-1)  # (T, N, 3)
 
         return uvd
-
-    def apply_motion_on_pts(self, pts, camera_motion):
-        tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
-        return tracking_pts
 
     def set_intr(self, K):
         if isinstance(K, np.ndarray):
             K = torch.from_numpy(K)
         self.intr = K.to(self.device)
+
+    def set_extr(self, extr):
+        …  # added setter used by app.py/demo.py above; body cut off in this view
 
     def rot_poses(self, angle, axis='y'):
         """Generate a single rotation matrix
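Note: the w2s body above is the standard pinhole projection. Per frame t, with pose E_t (world -> camera) and shared intrinsics K, a world point X maps to the exported (u, v, d) triple as:

\mathbf{x}^{\text{cam}}_t = E_t \begin{pmatrix}\mathbf{X}\\ 1\end{pmatrix},\qquad
\begin{pmatrix}u'\\ v'\\ w\end{pmatrix} = K\, \mathbf{x}^{\text{cam}}_t,\qquad
(u, v, d) = \left(\tfrac{u'}{w},\ \tfrac{v'}{w},\ z^{\text{cam}}_t\right).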
@@ -783,26 +1007,6 @@ class CameraMotionGenerator:
         camera_poses = np.concatenate(cam_poses, axis=0)
         return torch.from_numpy(camera_poses).to(self.device)
 
-    def rot(self, pts, angle, axis):
-        """
-        pts: torch.Tensor, (T, N, 2)
-        """
-        rot_mats = self.rot_poses(angle, axis)
-        pts = self.apply_motion_on_pts(pts, rot_mats)
-        return pts
-
-    def trans(self, pts, dx, dy, dz):
-        if pts.shape[-1] != 3:
-            raise ValueError("points should be in the 3d coordinate.")
-        trans_mats = self.trans_poses(dx, dy, dz)
-        pts = self.apply_motion_on_pts(pts, trans_mats)
-        return pts
-
-    def spiral(self, pts, radius):
-        spiral_poses = self.spiral_poses(radius)
-        pts = self.apply_motion_on_pts(pts, spiral_poses)
-        return pts
-
     def get_default_motion(self):
         """Parse motion parameters and generate corresponding motion matrices
 
@@ -820,6 +1024,7 @@ class CameraMotionGenerator:
         - if not specified, defaults to 0-49
         - frames after end_frame will maintain the final transformation
         - for combined transformations, they are applied in sequence
+        …
 
         Returns:
             torch.Tensor: Motion matrices [num_frames, 4, 4]
|
|
| 22 |
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
|
| 23 |
|
| 24 |
from submodules.MoGe.moge.model import MoGeModel
|
| 25 |
+
|
| 26 |
from image_gen_aux import DepthPreprocessor
|
| 27 |
from moviepy.editor import ImageSequenceClip
|
|
|
|
| 28 |
|
| 29 |
class DiffusionAsShaderPipeline:
|
| 30 |
def __init__(self, gpu_id=0, output_dir='outputs'):
|
|
|
|
        # device
        self.device = f"cuda:{gpu_id}"
        torch.cuda.set_device(gpu_id)
+        self.dtype = torch.bfloat16

        # files
        self.output_dir = output_dir

            transforms.ToTensor()
        ])

    @torch.no_grad()
    def _infer(
        self,

        tracking_tensor: torch.Tensor = None,
        image_tensor: torch.Tensor = None,  # [C,H,W] in range [0,1]
        output_path: str = "./output.mp4",
+        num_inference_steps: int = 25,
        guidance_scale: float = 6.0,
        num_videos_per_prompt: int = 1,
        dtype: torch.dtype = torch.bfloat16,

        pipe.text_encoder.eval()
        pipe.vae.eval()

+        self.dtype = dtype
+
        # Process tracking tensor
        tracking_maps = tracking_tensor.float()  # [T, C, H, W]
        tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)

    def _set_camera_motion(self, camera_motion):
        self.camera_motion = camera_motion

    ##============= SpatialTracker =============##
+
    def generate_tracking_spatracker(self, video_tensor, density=70):
        """Generate tracking video

        print("Loading tracking models...")
        # Load tracking model
        tracker = SpaTrackerPredictor(
+            checkpoint=os.path.join(project_root, 'checkpoints/spaT_final.pth'),
            interp_shape=(384, 576),
            seq_length=12
        ).to(self.device)
                progressive_tracking=False
            )

+            return pred_tracks.squeeze(0), pred_visibility.squeeze(0), T_Firsts

        finally:
            # Clean up GPU memory
            del tracker, self.depth_preprocessor
            torch.cuda.empty_cache()

    def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
        video = video.unsqueeze(0).to(self.device)
        vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)

                outline=tuple(color),
            )

    def visualize_tracking_moge(self, points, mask, save_tracking=True):
        """Visualize tracking results from MoGe model

            normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
            colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
        colors = colors.astype(np.uint8)

        points = points.reshape(T, -1, 3)
        colors = colors.reshape(-1, 3)

        # Initialize list to store frames
        frames = []

+        for i, pts_i in enumerate(tqdm(points, desc="rendering frames")):
            pixels, depths = pts_i[..., :2], pts_i[..., 2]
            pixels[..., 0] = pixels[..., 0] * W
            pixels[..., 1] = pixels[..., 1] * H
            tracking_path = None

        return tracking_path, tracking_video
+
+
+    ##============= CoTracker =============##
+
+    def generate_tracking_cotracker(self, video_tensor, density=70):
+        """Generate tracking video
+
+        Args:
+            video_tensor (torch.Tensor): Input video tensor
+
+        Returns:
+            tuple: (pred_tracks, pred_visibility)
+                - pred_tracks (torch.Tensor): Tracking points with depth [T, N, 3]
+                - pred_visibility (torch.Tensor): Visibility mask [T, N, 1]
+        """
+        # Generate tracking points
+        cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(self.device)
+
+        # Load depth model
+        if not hasattr(self, 'depth_preprocessor') or self.depth_preprocessor is None:
+            self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
+            self.depth_preprocessor.to(self.device)
+
+        try:
+            video = video_tensor.unsqueeze(0).to(self.device)
+
+            # Process all frames to get depth maps
+            video_depths = []
+            for i in tqdm(range(video_tensor.shape[0]), desc="estimating depth"):
+                frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+                depth = self.depth_preprocessor(Image.fromarray(frame))[0]
+                depth_tensor = transforms.ToTensor()(depth)  # [1, H, W]
+                video_depths.append(depth_tensor)
+
+            video_depth = torch.stack(video_depths, dim=0).to(self.device)  # [T, 1, H, W]
+
+            # Get tracking points and visibility
+            print("tracking...")
+            pred_tracks, pred_visibility = cotracker(video, grid_size=density)  # B T N 2, B T N 1
+
+            # Extract dimensions
+            B, T, N, _ = pred_tracks.shape
+            H, W = video_depth.shape[2], video_depth.shape[3]
+
+            # Create output tensor with depth
+            pred_tracks_with_depth = torch.zeros((B, T, N, 3), device=self.device)
+            pred_tracks_with_depth[:, :, :, :2] = pred_tracks  # Copy x,y coordinates
+
+            # Flatten batch and time so every frame's points can be indexed uniformly
+            flat_tracks = pred_tracks.reshape(B*T, N, 2)
+
+            # Clamp coordinates to valid image bounds
+            x_coords = flat_tracks[:, :, 0].clamp(0, W-1).long()  # [B*T, N]
+            y_coords = flat_tracks[:, :, 1].clamp(0, H-1).long()  # [B*T, N]
+
+            # Sample each point's depth from the depth map of its frame
+            depths = torch.zeros((B*T, N), device=self.device)
+            for bt in range(B*T):
+                t = bt % T  # Time index within the flattened batch
+                depths[bt] = video_depth[t, 0, y_coords[bt], x_coords[bt]]
+
+            # Reshape depths back to [B, T, N] and assign to output tensor
+            pred_tracks_with_depth[:, :, :, 2] = depths.reshape(B, T, N)
+
+            return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
+
+        finally:
+            del cotracker
+            torch.cuda.empty_cache()
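For context, the torch.hub entry point used above can be exercised on its own; a minimal sketch with a dummy clip (output shapes follow the inline comment above, B T N 2 and B T N 1):

import torch

cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline")
video = torch.randn(1, 49, 3, 480, 720)             # [B, T, C, H, W] dummy input
pred_tracks, pred_visibility = cotracker(video, grid_size=70)
print(pred_tracks.shape, pred_visibility.shape)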
+
+    def visualize_tracking_cotracker(self, points, vis_mask=None, save_tracking=True, point_wise=4, video_size=(480, 720)):
+        """Visualize tracking results from CoTracker
+
+        Args:
+            points (torch.Tensor): Points array of shape [T, N, 3]
+            vis_mask (torch.Tensor): Visibility mask of shape [T, N, 1]
+            save_tracking (bool): Whether to save tracking video
+            point_wise (int): Size of points in visualization
+            video_size (tuple): Render size (height, width)
+
+        Returns:
+            tuple: (tracking_path, tracking_video)
+        """
+        # Move tensors to CPU and convert to numpy
+        if isinstance(points, torch.Tensor):
+            points = points.detach().cpu().numpy()
+
+        if vis_mask is not None and isinstance(vis_mask, torch.Tensor):
+            vis_mask = vis_mask.detach().cpu().numpy()
+            # Reshape if needed
+            if vis_mask.ndim == 3 and vis_mask.shape[2] == 1:
+                vis_mask = vis_mask.squeeze(-1)
+
+        T, N, _ = points.shape
+        H, W = video_size
+
+        if vis_mask is None:
+            vis_mask = np.ones((T, N), dtype=bool)
+
+        colors = np.zeros((N, 3), dtype=np.uint8)
+
+        first_frame_pts = points[0]
+
+        u_min, u_max = 0, W
+        u_normalized = np.clip((first_frame_pts[:, 0] - u_min) / (u_max - u_min), 0, 1)
+        colors[:, 0] = (u_normalized * 255).astype(np.uint8)
+
+        v_min, v_max = 0, H
+        v_normalized = np.clip((first_frame_pts[:, 1] - v_min) / (v_max - v_min), 0, 1)
+        colors[:, 1] = (v_normalized * 255).astype(np.uint8)
+
+        z_values = first_frame_pts[:, 2]
+        if np.all(z_values == 0):
+            colors[:, 2] = np.random.randint(0, 256, N, dtype=np.uint8)
+        else:
+            inv_z = 1 / (z_values + 1e-10)
+            p2 = np.percentile(inv_z, 2)
+            p98 = np.percentile(inv_z, 98)
+            normalized_z = np.clip((inv_z - p2) / (p98 - p2 + 1e-10), 0, 1)
+            colors[:, 2] = (normalized_z * 255).astype(np.uint8)
+
+
frames = []
|
| 525 |
+
|
| 526 |
+
for i in tqdm(range(T), desc="rendering frames"):
|
| 527 |
+
pts_i = points[i]
|
| 528 |
+
|
| 529 |
+
visibility = vis_mask[i]
|
| 530 |
+
|
| 531 |
+
pixels, depths = pts_i[visibility, :2], pts_i[visibility, 2]
|
| 532 |
+
pixels = pixels.astype(int)
|
| 533 |
+
|
| 534 |
+
in_frame = self.valid_mask(pixels, W, H)
|
| 535 |
+
pixels = pixels[in_frame]
|
| 536 |
+
depths = depths[in_frame]
|
| 537 |
+
frame_rgb = colors[visibility][in_frame]
|
| 538 |
+
|
| 539 |
+
img = Image.fromarray(np.zeros((H, W, 3), dtype=np.uint8), mode="RGB")
|
| 540 |
+
|
| 541 |
+
sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
|
| 542 |
+
sorted_rgb = frame_rgb[sort_index]
|
| 543 |
+
|
| 544 |
+
for j in range(sorted_pixels.shape[0]):
|
| 545 |
+
self.draw_rectangle(
|
| 546 |
+
img,
|
| 547 |
+
coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
|
| 548 |
+
side_length=point_wise,
|
| 549 |
+
color=sorted_rgb[j],
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
frames.append(np.array(img))
|
| 553 |
+
|
| 554 |
+
# Convert frames to video tensor in range [0,1]
|
| 555 |
+
tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
|
| 556 |
+
|
| 557 |
+
tracking_path = None
|
| 558 |
+
if save_tracking:
|
| 559 |
+
try:
|
| 560 |
+
tracking_path = os.path.join(self.output_dir, "tracking_video_cotracker.mp4")
|
| 561 |
+
# Convert back to uint8 for saving
|
| 562 |
+
uint8_frames = [frame.astype(np.uint8) for frame in frames]
|
| 563 |
+
clip = ImageSequenceClip(uint8_frames, fps=self.fps)
|
| 564 |
+
clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
|
| 565 |
+
print(f"Video saved to {tracking_path}")
|
| 566 |
+
except Exception as e:
|
| 567 |
+
print(f"Warning: Failed to save tracking video: {e}")
|
| 568 |
+
tracking_path = None
|
| 569 |
+
|
| 570 |
+
return tracking_path, tracking_video
|
| 571 |
+
|
| 572 |
|
|
|
|
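The blue channel above encodes inverse depth, clipped to the 2nd-98th percentile so a few far outliers do not flatten the contrast. The same mapping in isolation:

import numpy as np

z = np.array([1.0, 2.0, 4.0, 50.0])                 # sample depths, near to far
inv_z = 1 / (z + 1e-10)                             # nearer points get larger values
p2, p98 = np.percentile(inv_z, 2), np.percentile(inv_z, 98)
blue = (np.clip((inv_z - p2) / (p98 - p2 + 1e-10), 0, 1) * 255).astype(np.uint8)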
    def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
        """Generate final video with motion transfer

            tracking_tensor=tracking_tensor,
            image_tensor=img_cond_tensor,
            output_path=final_output,
+            num_inference_steps=25,
            guidance_scale=6.0,
            dtype=torch.bfloat16,
            fps=self.fps
        """
        self.object_motion = motion_type

class FirstFrameRepainter:
    def __init__(self, gpu_id=0, output_dir='outputs'):
        """Initialize FirstFrameRepainter

        self.output_dir = output_dir
        self.max_depth = 65.0
        os.makedirs(output_dir, exist_ok=True)
+
    def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
        """Repaint first frame using Flux
        fx = fy = (W / 2) / math.tan(fov_rad / 2)

        self.intr[0, 0] = fx
+        self.intr[1, 1] = fy
+
+        self.extr = torch.eye(4, device=device)

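The focal length above follows from the horizontal field of view, fx = (W / 2) / tan(fov / 2). A quick worked example with hypothetical values (not the pipeline's defaults):

import math

W, fov_deg = 720, 60.0
fx = (W / 2) / math.tan(math.radians(fov_deg) / 2)  # 360 / tan(30 deg) ~= 623.5 px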
+    def s2w_vggt(self, points, extrinsics, intrinsics):
        """
+        Transform points from pixel coordinates to world coordinates
+
        Args:
+            points: Point cloud data of shape [T, N, 3] in uvz format
+            extrinsics: Camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
+            intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
+
+        Returns:
+            world_points: Point cloud in world coordinates [T, N, 3]
        """
+        if isinstance(points, torch.Tensor):
+            points = points.detach().cpu().numpy()
+
+        if isinstance(extrinsics, torch.Tensor):
+            extrinsics = extrinsics.detach().cpu().numpy()
+        # Handle batch dimension
+        if extrinsics.ndim == 4:  # [B, T, 3, 4]
+            extrinsics = extrinsics[0]  # Take first batch
+
+        if isinstance(intrinsics, torch.Tensor):
+            intrinsics = intrinsics.detach().cpu().numpy()
+        # Handle batch dimension
+        if intrinsics.ndim == 4:  # [B, T, 3, 3]
+            intrinsics = intrinsics[0]  # Take first batch
+
+        T, N, _ = points.shape
+        world_points = np.zeros_like(points)

+        # Extract uvz coordinates
+        uvz = points
+        valid_mask = uvz[..., 2] > 0

+        # Create homogeneous coordinates [u, v, 1]
+        uv_homogeneous = np.concatenate([uvz[..., :2], np.ones((T, N, 1))], axis=-1)
+
+        # Transform from pixel to camera coordinates
+        for i in range(T):
+            K = intrinsics[i]
+            K_inv = np.linalg.inv(K)
+
+            R = extrinsics[i, :, :3]
+            t = extrinsics[i, :, 3]
+
+            R_inv = np.linalg.inv(R)
+
+            valid_indices = np.where(valid_mask[i])[0]
+
+            if len(valid_indices) > 0:
+                valid_uv = uv_homogeneous[i, valid_indices]
+                valid_z = uvz[i, valid_indices, 2]
+
+                valid_xyz_camera = valid_uv @ K_inv.T
+                valid_xyz_camera = valid_xyz_camera * valid_z[:, np.newaxis]
+
+                # Transform from camera to world coordinates: X_world = R^-1 * (X_camera - t)
+                valid_world_points = (valid_xyz_camera - t) @ R_inv.T
+
+                world_points[i, valid_indices] = valid_world_points
+
+        return world_points

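The loop above inverts the pinhole model per frame: X_cam = z * K^-1 @ [u, v, 1], then X_world = R^-1 @ (X_cam - t). A one-point sanity check with hypothetical intrinsics and identity extrinsics:

import numpy as np

K = np.array([[600., 0., 360.], [0., 600., 240.], [0., 0., 1.]])
u, v, z = 390.0, 180.0, 2.0
x_cam = z * (np.linalg.inv(K) @ np.array([u, v, 1.0]))  # -> [0.1, -0.2, 2.0]
# With R = I and t = 0, the world point equals the camera point.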
+    def w2s_vggt(self, world_points, extrinsics, intrinsics, poses=None):
+        """
+        Project points from world coordinates to camera view
+
+        Args:
+            world_points: Point cloud in world coordinates [T, N, 3]
+            extrinsics: Original camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
+            intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
+            poses: Camera pose matrices [T, 4, 4], if None use first frame extrinsics
+
+        Returns:
+            camera_points: Point cloud in camera coordinates [T, N, 3] in uvz format
+        """
+        if isinstance(world_points, torch.Tensor):
+            world_points = world_points.detach().cpu().numpy()
+
+        if isinstance(extrinsics, torch.Tensor):
+            extrinsics = extrinsics.detach().cpu().numpy()
+        if extrinsics.ndim == 4:
+            extrinsics = extrinsics[0]
+
+        if isinstance(intrinsics, torch.Tensor):
+            intrinsics = intrinsics.detach().cpu().numpy()
+        if intrinsics.ndim == 4:
+            intrinsics = intrinsics[0]
+
+        T, N, _ = world_points.shape
+
+        # If no poses provided, use first frame extrinsics
+        if poses is None:
+            pose1 = np.eye(4)
+            pose1[:3, :3] = extrinsics[0, :, :3]
+            pose1[:3, 3] = extrinsics[0, :, 3]
+
+            camera_poses = np.tile(pose1[np.newaxis, :, :], (T, 1, 1))
+        else:
+            if isinstance(poses, torch.Tensor):
+                camera_poses = poses.cpu().numpy()
+            else:
+                camera_poses = poses
+
+            # Scale translation by 1/5
+            scaled_poses = camera_poses.copy()
+            scaled_poses[:, :3, 3] = camera_poses[:, :3, 3] / 5.0
+            camera_poses = scaled_poses
+
+        # Add homogeneous coordinates
+        ones = np.ones([T, N, 1])
+        world_points_hom = np.concatenate([world_points, ones], axis=-1)
+
+        # Transform points using batch matrix multiplication
+        pts_cam_hom = np.matmul(world_points_hom, np.transpose(camera_poses, (0, 2, 1)))
+        pts_cam = pts_cam_hom[..., :3]
+
+        # Extract depth information
+        depths = pts_cam[..., 2:3]
+        valid_mask = depths[..., 0] > 0
+
+        # Normalize coordinates
+        normalized_pts = pts_cam / (depths + 1e-10)
+
+        # Apply intrinsic matrix for projection
+        pts_pixel = np.matmul(normalized_pts, np.transpose(intrinsics, (0, 2, 1)))
+
+        # Extract pixel coordinates
+        u = pts_pixel[..., 0:1]
+        v = pts_pixel[..., 1:2]
+
+        # Set invalid points to zero
+        u[~valid_mask] = 0
+        v[~valid_mask] = 0
+        depths[~valid_mask] = 0
+
+        # Return points in uvz format
+        result = np.concatenate([u, v, depths], axis=-1)
+
+        return torch.from_numpy(result)

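Taken together, s2w_vggt and w2s_vggt are meant to round-trip tracks through world space so a new camera path can be injected; a hypothetical usage sketch (variable names illustrative, not from the repo):

world = cmg.s2w_vggt(uvz, extr, intr)                           # lift [T, N, 3] uvz tracks into world space
retargeted = cmg.w2s_vggt(world, extr, intr, poses=new_poses)   # re-project under the new camera path

Note that when poses are supplied, w2s_vggt damps their translations by a fixed factor of 5 before projecting, so the rendered camera motion is gentler than the raw pose trajectory.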
+    def w2s_moge(self, pts, poses):
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)
        assert poses.shape[0] == self.frame_num
        poses = poses.to(torch.float32).to(self.device)
        T, N, _ = pts.shape  # (T, N, 3)
        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
        ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
        points_world_h = torch.cat([pts, ones], dim=-1)
        points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))

        points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))

        uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
        depth = points_camera[:, :, 2:3]  # (T, N, 1)
        uvd = torch.cat([uv, depth], dim=-1)  # (T, N, 3)

+        return uvd

    def set_intr(self, K):
        if isinstance(K, np.ndarray):
            K = torch.from_numpy(K)
        self.intr = K.to(self.device)

+    def set_extr(self, extr):
+        if isinstance(extr, np.ndarray):
+            extr = torch.from_numpy(extr)
+        self.extr = extr.to(self.device)
+
    def rot_poses(self, angle, axis='y'):
        """Generate a single rotation matrix
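For reference, a rotation about the y axis by angle theta embedded in a 4x4 homogeneous pose, in the spirit of what rot_poses returns (a sketch under assumed semantics, not the method's exact code):

import numpy as np

theta = np.radians(15)
c, s = np.cos(theta), np.sin(theta)
pose = np.eye(4)                  # homogeneous pose, zero translation
pose[0, 0], pose[0, 2] = c, s     # standard R_y rotation block
pose[2, 0], pose[2, 2] = -s, c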
        camera_poses = np.concatenate(cam_poses, axis=0)
        return torch.from_numpy(camera_poses).to(self.device)

    def get_default_motion(self):
        """Parse motion parameters and generate corresponding motion matrices

            - if not specified, defaults to 0-49
            - frames after end_frame will maintain the final transformation
            - for combined transformations, they are applied in sequence
+            - moving left, up and zoom out is positive in video

        Returns:
            torch.Tensor: Motion matrices [num_frames, 4, 4]
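As an illustration of the documented return shape, a linear x-translation over a 49-frame clip could be materialized like this (a sketch under assumed semantics, not the parser itself):

import torch

num_frames, dx = 49, 0.5
motion = torch.eye(4).repeat(num_frames, 1, 1)        # [num_frames, 4, 4] identity poses
motion[:, 0, 3] = torch.linspace(0, dx, num_frames)   # translate along x across the clip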