import os.path
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

import numpy

import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


class SaveImagesHook(Callback):
    """Saves generated images to disk during validation and prediction epochs."""

    def __init__(self, save_dir="val", save_compressed=False):
        self.save_dir = save_dir
        self.save_compressed = save_compressed

    def save_start(self, target_dir):
        self.samples = []
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        elif os.listdir(self.target_dir) and "debug" not in str(self.target_dir):
            # Refuse to overwrite a previous run's outputs (debug dirs are exempt).
            raise FileExistsError(f"{self.target_dir} already exists and is not empty!")
        rank_zero_info(f"Saving images to {self.target_dir}")
        self._saved_num = 0

    def save_image(self, trainer, pl_module, images, metadatas):
        # Convert NCHW tensors to HWC numpy arrays and hand each one to a worker thread.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, metadata in zip(images, metadatas):
            save_fn = metadata.pop("save_fn", None)
            # Skip samples that carry no save_fn rather than submitting None to the pool.
            if save_fn is not None:
                self.executor_pool.submit(save_fn, sample, metadata, self.target_dir)
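
    # Each metadata dict is expected to carry a "save_fn" callable with the
    # signature save_fn(sample, metadata, target_dir). It runs on a worker
    # thread of the executor pool, so it must be thread-safe. An illustrative
    # sketch of such a callable appears after the class definition below.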

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        xT, y, metadata = batch  # only the metadata is used here
        b, c, h, w = samples.shape
        # Save individual images eagerly unless only the compressed archive is
        # wanted; even then, keep the first few samples for quick inspection.
        if not self.save_compressed or self._saved_num < 10:
            self._saved_num += b
            self.save_image(trainer, pl_module, samples, metadata)
        # Gather samples from all ranks; only rank zero accumulates them.
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.save_compressed and len(self.samples) > 0:
            # "arr_0" is numpy's default savez key, so consumers can recover
            # the array with numpy.load(path)["arr_0"].
            samples = numpy.concatenate(self.samples)
            numpy.savez(f"{self.target_dir}/output.npz", arr_0=samples)
        # Wait for all pending image writes to finish before tearing down.
        self.executor_pool.shutdown(wait=True)
        self.target_dir = None
        self.executor_pool = None
        self.samples = []

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def state_dict(self) -> Dict[str, Any]:
        # The callback keeps no state that needs checkpointing.
        return dict()
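
For reference, the hook expects each batch to unpack as (xT, y, metadata), with every metadata dict carrying a "save_fn" callable. What follows is a minimal sketch, not part of the original file: the save_png name, the Pillow dependency, the assumed [0, 1] sample range, the "index" metadata key, and the model/val_loader placeholders are all illustrative assumptions.

    import os

    import numpy as np
    from PIL import Image


    def save_png(sample, metadata, target_dir):
        # sample arrives as an HWC numpy array; a [0, 1] float range is assumed
        # here. "save_fn" has already been popped, so only user keys remain.
        array = (np.clip(sample, 0.0, 1.0) * 255).astype(np.uint8)
        filename = f"{metadata.get('index', 0):06d}.png"
        Image.fromarray(array).save(os.path.join(target_dir, filename))


    # Wiring the hook into a Trainer; model and val_loader stand in for your
    # LightningModule and dataloader, whose batches must yield (xT, y, metadata)
    # with metadata[i]["save_fn"] = save_png.
    trainer = pl.Trainer(
        default_root_dir="runs/exp1",
        callbacks=[SaveImagesHook(save_dir="val", save_compressed=True)],
    )
    trainer.validate(model, dataloaders=val_loader)

Routing writes through a thread pool keeps slow disk I/O off the validation loop's critical path, while shutdown(wait=True) in save_end ensures every image has landed before the directory handle is cleared.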