Spaces:
Build error
Build error
| from typing import Dict, Any | |
| import os | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from pytorch_lightning.utilities.types import STEP_OUTPUT | |
| import torch | |
| import torchvision | |
| from PIL import Image | |
| from pytorch_lightning.callbacks import Callback | |
| from pytorch_lightning.utilities.distributed import rank_zero_only | |
| from .mixins import ImageLoggerMixin | |
# Public names exported by this module: Lightning's ModelCheckpoint (imported
# from pytorch_lightning.callbacks) and the ImageLogger callback defined here.
__all__ = ["ModelCheckpoint", "ImageLogger"]
class ImageLogger(Callback):
    """
    Log images during training or validating.

    Every ``log_every_n_steps`` global steps this callback calls
    ``pl_module.log_images(batch, **log_images_kwargs)`` with gradient
    tracking disabled, then writes one PNG grid per returned key to
    ``<save root>/image_log/train``. The save root is the logger's
    ``save_dir``, falling back to ``trainer.default_root_dir`` when no logger
    (or no logger ``save_dir``) is available.

    TODO: Support validating.
    """

    def __init__(
        self,
        log_every_n_steps: int = 2000,
        max_images_each_step: int = 4,
        log_images_kwargs: Dict[str, Any] = None,
    ) -> None:
        """
        Args:
            log_every_n_steps: Images are logged whenever ``global_step`` is a
                multiple of this value.
            max_images_each_step: Maximum number of images per key placed into
                a single grid.
            log_images_kwargs: Extra keyword arguments forwarded to
                ``pl_module.log_images``. ``None`` (or an empty dict) means no
                extra arguments.
        """
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        self.max_images_each_step = max_images_each_step
        self.log_images_kwargs = log_images_kwargs or dict()

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Verify up front that ``pl_module`` is an ``ImageLoggerMixin``.

        This callback later calls ``pl_module.log_images``, so the check fails
        fast instead of erroring at the first logging step.

        Raises:
            TypeError: If ``pl_module`` does not derive from ``ImageLoggerMixin``.
        """
        # An ``assert`` would be silently stripped under ``python -O``.
        if not isinstance(pl_module, ImageLoggerMixin):
            raise TypeError(
                "ImageLogger requires a LightningModule deriving from "
                "ImageLoggerMixin, got {}".format(type(pl_module).__name__)
            )

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Generate and save image grids every ``log_every_n_steps`` steps.

        Args:
            trainer: The running trainer.
            pl_module: The module being trained; must provide ``log_images``.
            outputs: Output of the training step (unused).
            batch: The current training batch, forwarded to ``log_images``.
            batch_idx: Index of the batch within the epoch (used in filenames).
            dataloader_idx: Unused. Lightning versions differ in whether they
                pass this argument, hence the default.
        """
        if pl_module.global_step % self.log_every_n_steps != 0:
            return

        was_training = pl_module.training
        saved_flags = []
        if was_training:
            # Remember every parameter's requires_grad flag so it can be restored
            # exactly. LightningModule.unfreeze() would instead set
            # requires_grad=True on *all* parameters, un-freezing weights that the
            # model had deliberately frozen.
            saved_flags = [(param, param.requires_grad) for param in pl_module.parameters()]
            pl_module.freeze()
        try:
            with torch.no_grad():
                # returned images should be: nchw, rgb, [0, 1]
                images: Dict[str, torch.Tensor] = pl_module.log_images(
                    batch, **self.log_images_kwargs
                )
        finally:
            # Restore the original training state even if log_images raises.
            if was_training:
                for param, flag in saved_flags:
                    param.requires_grad = flag
                pl_module.train()

        # log_images runs on every rank (it may involve collective ops), but only
        # rank zero writes files: all ranks would otherwise race on identical paths.
        self._save_images(trainer, pl_module, images, batch_idx)

    @rank_zero_only
    def _save_images(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        images: Dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Write one PNG grid per entry of ``images`` (executed on rank zero only)."""
        logger = pl_module.logger
        save_root = getattr(logger, "save_dir", None) if logger is not None else None
        if save_root is None:
            save_root = trainer.default_root_dir
        save_dir = os.path.join(save_root, "image_log", "train")
        os.makedirs(save_dir, exist_ok=True)

        for image_key, image in images.items():
            image = image.detach().cpu()
            n_images = min(self.max_images_each_step, len(image))
            grid = torchvision.utils.make_grid(image[:n_images], nrow=4)
            # chw -> hwc (hw if gray)
            grid = grid.permute(1, 2, 0).squeeze(-1).numpy()
            grid = (grid * 255).clip(0, 255).astype(np.uint8)
            filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format(
                image_key, pl_module.global_step, pl_module.current_epoch, batch_idx
            )
            Image.fromarray(grid).save(os.path.join(save_dir, filename))