Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| from mmcv import Config | |
| from pytorch_lightning import Trainer | |
| from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping | |
| from pytorch_lightning.loggers import WandbLogger | |
| from pytorch_lightning.utilities.seed import seed_everything | |
| import wandb | |
| from risk_biased.utils.callbacks import SwitchTrainingModeCallback | |
| from risk_biased.utils.callbacks import ( | |
| HistogramCallback, | |
| PlotTrajCallback, | |
| DrawCallbackParams, | |
| ) | |
| from risk_biased.utils.load_model import load_from_config | |
| from scripts.scripts_utils.load_utils import get_config | |
def create_log_dir():
    """Create (if needed) and return the "logs" directory next to this script.

    Returns:
        str: Absolute path to the log directory.
    """
    working_dir = os.path.dirname(os.path.realpath(__file__))
    log_dir = os.path.join(working_dir, "logs")
    # makedirs(..., exist_ok=True) avoids the race between the existence
    # check and the creation, and also handles missing intermediate dirs.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
def save_log_config(cfg: Config, predictor):
    """Dump the effective config and auxiliary files into the wandb run dir.

    The config is re-serialized (not merely copied from disk) because
    command-line overrides may have changed values since the config file
    was read.

    Args:
        cfg: Full experiment configuration.
        predictor: Model whose weights/gradients may be watched by wandb.
    """
    run_dir = wandb.run.dir
    # Serialize the (possibly overridden) config and register it with wandb.
    config_dump_path = os.path.join(run_dir, "learning_config.py")
    cfg.dump(config_dump_path)
    wandb.save(config_dump_path)
    # Copy each listed auxiliary file into the run dir and register it too.
    for source_path in cfg.files_to_log:
        destination_path = os.path.join(run_dir, os.path.basename(source_path))
        shutil.copy(source_path, destination_path)
        wandb.save(destination_path)
    if cfg.log_weights_and_grads:
        wandb.watch(predictor, log="all", log_freq=100)
def create_callbacks(cfg: Config, log_dir: str, is_interaction: bool) -> list:
    """Build the PyTorch Lightning callback list for training.

    Args:
        cfg: Experiment configuration (reads early_stopping, num_epochs_cvae,
            and the DrawCallbackParams fields).
        log_dir: Local log directory used for the "last run" checkpoints.
        is_interaction: When True, skip the histogram/trajectory-plot
            callbacks (they are only added for the non-interaction setup).

    Returns:
        list: Configured callbacks for the Trainer.
    """
    # Both checkpoint callbacks track the same metric with the same filename
    # pattern; only the destination directory differs, so the shared kwargs
    # are factored out to keep the two definitions in sync.
    checkpoint_kwargs = dict(
        monitor="val/minfde/prior",
        mode="min",
        filename="epoch={epoch:02d}-step={step}-val_minfde_prior={val/minfde/prior:.2f}",
        auto_insert_metric_name=False,
        save_last=True,
    )
    callbacks = [
        # Save checkpoint of the last model in a stable, run-independent dir.
        ModelCheckpoint(
            dirpath=os.path.join(log_dir, "checkpoints_last_run"),
            **checkpoint_kwargs,
        ),
        # Save checkpoints of the current run in the current wandb log dir.
        ModelCheckpoint(dirpath=wandb.run.dir, **checkpoint_kwargs),
    ]
    if not is_interaction:
        callbacks.append(
            HistogramCallback(
                params=DrawCallbackParams.from_config(cfg),
                n_samples=1000,
            )
        )
        callbacks.append(
            PlotTrajCallback(
                params=DrawCallbackParams.from_config(cfg), n_samples=10
            )
        )
    if cfg.early_stopping:
        callbacks.append(
            EarlyStopping(
                monitor="val/minfde/prior",
                # NOTE(review): a NEGATIVE min_delta means changes up to 0.2
                # WORSE still count as improvement, so stopping only triggers
                # on a clear degradation — confirm this is intentional.
                min_delta=-0.2,
                patience=5,
                verbose=False,
                mode="min",
            )
        )
    # Switch from CVAE training to bias training after num_epochs_cvae epochs.
    callbacks.append(
        SwitchTrainingModeCallback(switch_at_epoch=cfg.num_epochs_cvae)
    )
    return callbacks
def get_trainer(cfg: Config, logger: WandbLogger, callbacks: list) -> Trainer:
    """Instantiate a Trainer covering both the CVAE and the bias phases.

    Args:
        cfg: Experiment configuration.
        logger: WandbLogger attached to the current run.
        callbacks: Callback list built by create_callbacks.

    Returns:
        Trainer: Configured PyTorch Lightning trainer.
    """
    # Both training phases run back to back within a single fit() call.
    total_epochs = cfg.num_epochs_cvae + cfg.num_epochs_bias
    trainer_kwargs = dict(
        gpus=cfg.gpus,
        max_epochs=total_epochs,
        logger=logger,
        val_check_interval=float(cfg.val_check_interval_epoch),
        accumulate_grad_batches=cfg.accumulate_grad_batches,
        callbacks=callbacks,
    )
    return Trainer(**trainer_kwargs)
def main(is_interaction: bool = False):
    """Entry point: set up logging, build the model, and run training.

    Args:
        is_interaction: When True, train on the interaction setup and skip
            the plotting callbacks that are not defined for it.
    """
    log_dir = create_log_dir()
    cfg = get_config(log_dir, is_interaction)
    # Seed BEFORE constructing the model and dataloaders so that weight
    # initialization and data shuffling are reproducible. The previous code
    # seeded after load_from_config, leaving model init unseeded.
    if cfg.seed is not None:
        seed_everything(cfg.seed)
    predictor, dataloaders, cfg = load_from_config(cfg)
    save_log_config(cfg, predictor)
    logger = WandbLogger(
        project=cfg.project, log_model=True, save_dir=log_dir, id=wandb.run.id
    )
    callbacks = create_callbacks(cfg, log_dir, is_interaction)
    trainer = get_trainer(cfg, logger, callbacks)
    trainer.fit(
        predictor,
        train_dataloaders=dataloaders.train_dataloader(),
        val_dataloaders=dataloaders.val_dataloader(),
    )
    wandb.finish()
if __name__ == "__main__":
    # Script entry point: this file trains on the interaction setup.
    main(is_interaction=True)