from functools import partial
from typing import List, Tuple

import logging
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block
					
						
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        `np.ndarray` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1 + grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token).
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
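# Illustrative usage of get_3d_sincos_pos_embed (the sizes below are example values,
# not mandated by the module):
#
#   pos = get_3d_sincos_pos_embed(embed_dim=768, grid_size=(3, 14, 14), add_cls_token=True)
#   assert pos.shape == (1 + 3 * 14 * 14, 768)
#
# The width and height components each use 6/16 of embed_dim and the temporal component
# 4/16, which is why embed_dim must be divisible by 16.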
					
						
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb
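# Illustrative usage (example values, not from the module itself):
#
#   emb = get_1d_sincos_pos_embed_from_grid(16, np.arange(10))
#   assert emb.shape == (10, 16)   # first 8 dims are sin terms, last 8 are cos terms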
					
						
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """
    Torch version of `get_1d_sincos_pos_embed_from_grid()`. It was modified to cast the omega
    values to pos.dtype, which must be a float type (not int as in the regular positional
    embeddings). This is required for native FSDP mixed-precision support: omega follows the
    (float) dtype carried by pos instead of being forced to float32.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = torch.einsum("m,d->md", pos, omega)

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    emb = torch.cat([emb_sin, emb_cos], dim=1)

    return emb
					
						
def _init_weights(module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
					
						
class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed"""
    def __init__(
            self,
            input_size: Tuple[int, int, int] = (1, 224, 224),
            patch_size: Tuple[int, int, int] = (1, 16, 16),
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: nn.Module | None = None,
            flatten: bool = True,
            bias: bool = True,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape

        if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
            logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
                            f"The border will be ignored, add backbone_padding for pixel-wise tasks.")

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B, C, T, H, W -> B, N, C
        x = self.norm(x)
        return x
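# Illustrative usage of PatchEmbed (shapes only; the sizes below are example values):
#
#   embed = PatchEmbed(input_size=(1, 224, 224), patch_size=(1, 16, 16), in_chans=6, embed_dim=768)
#   tokens = embed(torch.randn(2, 6, 1, 224, 224))   # -> (2, 14 * 14, 768)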
					
						
class TemporalEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
        """
        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding
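# Illustrative usage of TemporalEncoder (example values):
#
#   enc = TemporalEncoder(embed_dim=768)
#   coords = torch.tensor([[[2020.0, 120.0], [2020.0, 150.0]]])   # (B=1, T=2, [year, day-of-year])
#   emb = enc(coords, tokens_per_frame=196)                       # -> (1, 2 * 196, 768)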
					
						
class LocationEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).
        """
        shape = location_coords.shape[:1] + (1, -1)

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding
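# Illustrative usage of LocationEncoder (example coordinates):
#
#   enc = LocationEncoder(embed_dim=768)
#   emb = enc(torch.tensor([[45.5, -73.6]]))   # (B=1, [lat, lon]) -> (1, 1, 768)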
					
						
class PrithviViT(nn.Module):
    """ Prithvi ViT Encoder"""
    def __init__(self,
                 img_size: int | Tuple[int, int] = 224,
                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 1,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 depth: int = 24,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
                 coords_encoding: List[str] | None = None,
                 coords_scale_learn: bool = False,
                 encoder_only: bool = True,
                 **kwargs,
                 ):
        super().__init__()

        self.feature_info = []
        self.encoder_only = encoder_only
        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # Optional temporal and location embeddings
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
            self.feature_info.append(
                {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
            )
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()
					
						
    def initialize_weights(self):
        # Initialize (and freeze) the position embeddings with the sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the patch embedding projection like an nn.Linear layer
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)
					
						
    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1)

        # sort noise for each sample: small values are kept, large values are removed
        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore
					
						
    def _get_pos_embed(self, x):
        t, h, w = x.shape[-3:]

        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
            self.embed_dim,
            (
                t // self.patch_embed.patch_size[0],
                h // self.patch_embed.patch_size[1],
                w // self.patch_embed.patch_size[2],
            ),
            add_cls_token=True,
        )).float().unsqueeze(0).to(x)

        return pos_embed
					
						
    def forward(
        self, x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.75
    ):
        if x.shape[-3:] != self.patch_embed.input_size:
            # input size changed, recompute the position embeddings
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed

        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore
					
						
    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)

        if x.shape[-3:] != self.patch_embed.input_size:
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed

        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks, collecting the output of each block
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out
					
						
    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out
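# Illustrative encoder-only usage (the small sizes below are example values chosen to keep
# the example fast):
#
#   vit = PrithviViT(img_size=224, patch_size=(1, 16, 16), num_frames=1, in_chans=6,
#                    embed_dim=64, depth=2, num_heads=4)
#   feats = vit.forward_features(torch.randn(2, 6, 224, 224))   # list of (2, 1 + 196, 64) tensors
#   maps = vit.prepare_features_for_image_model(feats)          # list of (2, 64, 14, 14) feature maps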
					
						
class MAEDecoder(nn.Module):
    """ Transformer Decoder used in the Prithvi MAE"""
    def __init__(self,
                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
                 grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 depth: int = 8,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 coords_encoding: List[str] | None = None,
                 coords_scale_learn: bool = False,
                 ):
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim,
                                      patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
                                      bias=True)

        self.initialize_weights()
					
						
    def initialize_weights(self):
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)
					
						
    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] = None,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)

        t, h, w = input_size[-3:]
        decoder_pos_embed = torch.from_numpy(
            get_3d_sincos_pos_embed(
                self.decoder_embed_dim,
                (
                    t // self.patch_size[0],
                    h // self.patch_size[1],
                    w // self.patch_size[2],
                ),
                add_cls_token=True,
            )
        ).to(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        # unshuffle
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        # add pos embed
        x = x + decoder_pos_embed

        # remove cls token
        x_ = x[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x_.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
            # add temporal encoding w/o cls token
            x_ = x_ + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_dec(location_coords)
            # add location encoding w/o cls token
            x_ = x_ + location_encoding

        # re-append cls token
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # apply Transformer blocks
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred
					
						
class PrithviMAE(nn.Module):
    """ Prithvi Masked Autoencoder"""

    def __init__(self,
                 img_size: int | Tuple[int, int] = 224,
                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 3,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 depth: int = 24,
                 num_heads: int = 16,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
                 norm_pix_loss: bool = False,
                 coords_encoding: List[str] | None = None,
                 coords_scale_learn: bool = False,
                 encoder_only: bool = False,
                 **kwargs,
                 ):
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.encoder_only = encoder_only

        if not encoder_only:
            self.decoder = MAEDecoder(
                patch_size=patch_size,
                grid_size=self.encoder.patch_embed.grid_size,
                in_chans=in_chans,
                encoder_embed_dim=embed_dim,
                decoder_embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                coords_encoding=coords_encoding,
                coords_scale_learn=coords_scale_learn,
            )
        else:
            self.decoder = nn.Identity()

        self.norm_pix_loss = norm_pix_loss
					
						
    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
                                            c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)

        return patchified_pixel_values
					
						
    def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Patchified pixel values.
            image_size (`Tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
                                 c=num_channels, h=num_patches_h, w=num_patches_w,
                                 s=patch_size_t, p=patch_size_h, q=patch_size_w)
        return pixel_values
					
						
    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss
					
						
    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = 0.75
    ):
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
        pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask
					
						
    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> List[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)
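
if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module). The small
    # dimensions below are arbitrary example values chosen so the forward pass runs quickly.
    model = PrithviMAE(
        img_size=224,
        patch_size=(1, 16, 16),
        num_frames=1,
        in_chans=6,
        embed_dim=64,
        depth=2,
        num_heads=4,
        decoder_embed_dim=32,
        decoder_depth=1,
        decoder_num_heads=2,
    )
    pixels = torch.randn(2, 6, 1, 224, 224)  # (batch, channels, time, height, width)
    loss, pred, mask = model(pixels, mask_ratio=0.75)
    print(loss.item(), pred.shape, mask.shape)  # pred: (2, 196, 1536), mask: (2, 196)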