Spaces:
Runtime error
Runtime error
| import os | |
| import diffusion | |
| from pytorch_lightning.callbacks import ( | |
| ModelCheckpoint, | |
| LearningRateMonitor | |
| ) | |
class ModelCallback:
    """Build the set of Lightning callbacks used for a training run.

    On construction, ``<root_path>/model/`` is created if missing, and three
    callbacks are configured:

    * ``ModelCheckpoint``: saves weights only into ``<root_path>/model/``
      under the filename ``"model"``, keeping the top 1 checkpoint ranked by
      ``ckpt_monitor`` using ``ckpt_mode``.
    * ``LearningRateMonitor``: logs the learning rate at step granularity.
    * ``diffusion.EMACallback``: constructed with ``decay=ema_decay``.
      Its behaviour is defined in the project's ``diffusion`` module.
    """

    def __init__(
        self,
        root_path: str,
        ckpt_monitor: str = "val_loss",
        ckpt_mode: str = "min",
        ema_decay: float = 0.995,
    ):
        """Create the checkpoint directory and instantiate the callbacks.

        Args:
            root_path: Base output directory. Checkpoints go to
                ``<root_path>/model/``.
            ckpt_monitor: Logged metric name that ``ModelCheckpoint`` ranks
                checkpoints by.
            ckpt_mode: ``"min"`` or ``"max"``; whether a lower or higher
                ``ckpt_monitor`` value is better.
            ema_decay: Decay passed to ``diffusion.EMACallback``. The default
                matches the previously hard-coded value.

        Raises:
            OSError: If the checkpoint directory cannot be created (e.g. a
                regular file already occupies that path).
        """
        ckpt_path = os.path.join(root_path, "model/")
        # makedirs creates root_path as an intermediate directory, and
        # exist_ok avoids the check-then-create race, so one call replaces
        # the two exists()/makedirs() pairs.
        os.makedirs(ckpt_path, exist_ok=True)

        self.ckpt_callback = ModelCheckpoint(
            monitor=ckpt_monitor,
            dirpath=ckpt_path,
            filename="model",
            save_top_k=1,
            mode=ckpt_mode,
            save_weights_only=True,
        )
        self.lr_callback = LearningRateMonitor(logging_interval="step")
        self.ema_callback = diffusion.EMACallback(decay=ema_decay)

    def get_callback(self) -> list:
        """Return the callbacks in the order: checkpoint, LR monitor, EMA."""
        return [self.ckpt_callback, self.lr_callback, self.ema_callback]