Spaces:
Build error
Build error
| import torch | |
| from torch import Tensor, nn | |
| def rollout_iter( | |
| nsteps: int, | |
| model: nn.Module, | |
| batch: dict[str, Tensor | int | float], | |
| ) -> Tensor: | |
| """A helper function for performing autoregressive rollout. | |
| Args: | |
| nsteps (int): The number of rollout steps to take | |
| model (nn.Module): A model. | |
| batch (dict): A data dictionary common to the Prithvi models. | |
| Raises: | |
| ValueError: If the number of steps isn't positive. | |
| Returns: | |
| Tensor: the output of the model after nsteps autoregressive iterations. | |
| """ | |
| if nsteps < 1: | |
| raise ValueError("'nsteps' shouold be a positive int.") | |
| xlast = batch["x"][:, 1] | |
| batch["lead_time"] = batch["lead_time"][..., 0] | |
| # Save the masking ratio to be restored later | |
| mask_ratio_tmp = model.mask_ratio_inputs | |
| for step in range(nsteps): | |
| # After first step, turn off masking | |
| if step > 0: | |
| model.mask_ratio_inputs = 0.0 | |
| batch["static"] = batch["statics"][:, step] | |
| batch["climate"] = batch["climates"][:, step] | |
| batch["y"] = batch["ys"][:, step] | |
| out = model(batch) | |
| batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1) | |
| xlast = out | |
| # Restore the masking ratio | |
| model.mask_ratio_inputs = mask_ratio_tmp | |
| return xlast | |