import argparse
import json
import random
from pathlib import Path

import imageio
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel
from tqdm import tqdm


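# Inference script for the Terra world model (turing-motors/Terra): generate a
# video continuation of a few conditioning frames, following one of the
# trajectory commands in the template file, and save the result as an MP4.

# Generation settings: frame resolution (H, W), rolling-window length, discrete
# tokens per frame, how trajectory waypoints are subsampled into action tokens,
# and which image tokenizer pairs with which world model checkpoint.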
IMAGE_SIZE = (288, 512)
N_FRAMES_PER_ROUND = 25
MAX_NUM_FRAMES = 50
N_TOKENS_PER_FRAME = 576
TRAJ_TEMPLATE_PATH = Path("./assets/template_trajectory.json")
PATH_START_ID = 9
PATH_POINT_INTERVAL = 10
N_ACTION_TOKENS = 6
WM_TOKENIZER_COMBINATION = {
    "world_model": "lfq_tokenizer_B_256",
    "world_model_v2": "lfq_tokenizer_B_256_ema",
}

CONDITIONING_FRAMES_DIR = Path("./assets/conditioning_frames")
CONDITIONING_FRAMES_PATH_LIST = [
    CONDITIONING_FRAMES_DIR / "001.png",
    CONDITIONING_FRAMES_DIR / "002.png",
    CONDITIONING_FRAMES_DIR / "003.png"
]


def set_random_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


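# Convert a PIL image to a (1, 3, H, W) float tensor with pixel values scaled
# to [-1, 1].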
def preprocess_image(image: Image.Image, size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    H, W = size
    image = image.convert("RGB")
    image = image.resize((W, H))
    image_array = np.array(image)
    image_array = (image_array / 127.5 - 1.0).astype(np.float32)
    return torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0).float()


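# Map frames from [-1, 1] back to uint8 images in (N, H, W, C) layout.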
def to_np_images(images: torch.Tensor) -> np.ndarray:
    images = images.detach().cpu()
    images = torch.clamp(images, -1., 1.)
    images = (images + 1.) / 2.
    images = images.permute(0, 2, 3, 1).numpy()
    return (255 * images).astype(np.uint8)


def load_images(file_path_list: list[Path], size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    images = []
    for file_path in file_path_list:
        image = Image.open(file_path)
        image = preprocess_image(image, size)
        images.append(image)
    return torch.cat(images, dim=0)


def save_images_to_mp4(images: np.ndarray, output_path: Path, fps: int = 10):
    writer = imageio.get_writer(output_path, fps=fps)
    for img in images:
        writer.append_data(img)
    writer.close()


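# Number of generation rounds needed to cover `num_frames`, assuming each round
# beyond the initial conditioning frames contributes
# (N_FRAMES_PER_ROUND - num_overlapping_frames) new frames (rounded up).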
def determine_num_rounds(num_frames: int, num_overlapping_frames: int, n_initial_frames: int) -> int:
    n_rounds = (num_frames - n_initial_frames) // (N_FRAMES_PER_ROUND - num_overlapping_frames)
    if (num_frames - n_initial_frames) % (N_FRAMES_PER_ROUND - num_overlapping_frames) > 0:
        n_rounds += 1
    return n_rounds


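# Assemble the action tensor for `n_frames` frames from the trajectory template:
# for each frame, take every `path_point_interval`-th waypoint starting at
# `path_start_id`, keep `n_action_tokens` of them, swap the first two coordinate
# columns, and append the matching timestep from a 0.05-spaced grid.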
def prepare_action(
    traj_template: dict,
    cmd: str,
    path_start_id: int,
    path_point_interval: int,
    n_action_tokens: int = 5,
    start_index: int = 0,
    n_frames: int = 25
) -> torch.Tensor:
    trajs = traj_template[cmd]["instruction_trajs"]
    actions = []
    timesteps = np.arange(0.0, 3.0, 0.05)
    for i in range(start_index, start_index + n_frames):
        traj = trajs[i][path_start_id::path_point_interval][:n_action_tokens]
        action = np.array(traj)
        timestep = timesteps[path_start_id::path_point_interval][:n_action_tokens]
        action = np.concatenate([
            action[:, [1, 0]],
            timestep.reshape(-1, 1)
        ], axis=1)
        actions.append(torch.tensor(action))
    return torch.cat(actions, dim=0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--cmd", type=str, default="curving_to_left/curving_to_left_moderate")
    parser.add_argument("--num_frames", type=int, default=25)
    parser.add_argument("--num_overlapping_frames", type=int, default=3)
    parser.add_argument("--model_name", type=str, default="world_model_v2")
    args = parser.parse_args()

    assert args.num_frames <= MAX_NUM_FRAMES, f"`num_frames` should be less than or equal to {MAX_NUM_FRAMES}"
    assert args.num_overlapping_frames < N_FRAMES_PER_ROUND, f"`num_overlapping_frames` should be less than {N_FRAMES_PER_ROUND}"

    set_random_seed(args.seed)
    if args.output_dir is None:
        output_dir = Path(f"./outputs/{args.cmd}")
    else:
        output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer_name = WM_TOKENIZER_COMBINATION[args.model_name]
    tokenizer = AutoModel.from_pretrained("turing-motors/Terra", subfolder=tokenizer_name, trust_remote_code=True).to(device).eval()
    model = AutoModel.from_pretrained("turing-motors/Terra", subfolder=args.model_name, trust_remote_code=True).to(device).eval()

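    # Tokenize the conditioning frames into the discrete ids that seed the first
    # generation round.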
    conditioning_frames = load_images(CONDITIONING_FRAMES_PATH_LIST, IMAGE_SIZE).to(device)
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        input_ids = tokenizer.tokenize(conditioning_frames).detach().unsqueeze(0)

    num_rounds = determine_num_rounds(args.num_frames, args.num_overlapping_frames, len(CONDITIONING_FRAMES_PATH_LIST))
    print(f"Number of generation rounds: {num_rounds}")

    with open(TRAJ_TEMPLATE_PATH) as f:
        traj_template = json.load(f)

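    # Generate in rounds of up to N_FRAMES_PER_ROUND frames. After each round,
    # the tokens of the last `num_overlapping_frames` frames are fed back as the
    # prompt for the next round; the overlapping prefix is dropped from later
    # rounds so that no frame is duplicated in the final output.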
    all_outputs = []
    for round in range(num_rounds):
        start_index = round * (N_FRAMES_PER_ROUND - args.num_overlapping_frames)
        num_frames_for_round = min(N_FRAMES_PER_ROUND, args.num_frames - start_index)
        actions = prepare_action(
            traj_template, args.cmd, PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS, start_index, num_frames_for_round
        ).unsqueeze(0).to(device).float()
        if round == 0:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - len(CONDITIONING_FRAMES_PATH_LIST))
        else:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - args.num_overlapping_frames)
        progress_bar = tqdm(total=num_generated_tokens, desc=f"Round {round + 1}")
        with torch.inference_mode(), torch.autocast(device_type="cuda"):
            output_tokens = model.generate(
                input_ids=input_ids,
                actions=actions,
                do_sample=True,
                max_length=N_TOKENS_PER_FRAME * num_frames_for_round,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                pad_token_id=None,
                eos_token_id=None,
                progress_bar=progress_bar
            )
        if round == 0:
            all_outputs.append(output_tokens[0])
        else:
            all_outputs.append(output_tokens[0, args.num_overlapping_frames * N_TOKENS_PER_FRAME:])
        input_ids = output_tokens[:, -args.num_overlapping_frames * N_TOKENS_PER_FRAME:]
        progress_bar.close()

    output_ids = torch.cat(all_outputs)

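    # Recover the latent grid shape expected by `decode_tokens`: the spatial
    # downsampling factor is computed from the tokenizer's `ch_mult` config and
    # the channel count from `z_channels`.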
    downsample_ratio = 1
    for coef in tokenizer.config.encoder_decoder_config["ch_mult"]:
        downsample_ratio *= coef
    h = IMAGE_SIZE[0] // downsample_ratio
    w = IMAGE_SIZE[1] // downsample_ratio
    c = tokenizer.config.encoder_decoder_config["z_channels"]
    latent_shape = (len(output_ids) // N_TOKENS_PER_FRAME, h, w, c)

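    # Decode the generated tokens back to RGB frames and write them out as an MP4.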
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        reconstructed = tokenizer.decode_tokens(output_ids, latent_shape)
    reconstructed_images = to_np_images(reconstructed)
    save_images_to_mp4(reconstructed_images, output_dir / "generated.mp4", fps=10)