Spaces:
Build error
Build error
| import datetime | |
| from argparse import ArgumentParser | |
| import torch | |
| from lightning import Trainer | |
| from lightning.pytorch.loggers import TensorBoardLogger | |
| from lightning.pytorch.callbacks import ModelSummary | |
| from src.trainer import ViTLightningModule | |
def build_parser() -> ArgumentParser:
    """Build the command-line parser for the trainer script.

    Returns:
        An ``ArgumentParser`` accepting ``--tag``, ``--debug``,
        ``--convert-checkpoint`` and ``--output-dir``.
    """
    parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
    parser.add_argument('--tag', action='store', type=str,
                        help='Extra suffix to put on the artefact dir name')
    parser.add_argument('--debug', action='store_true',
                        help="Dummy training cycle for testing purposes")
    parser.add_argument('--convert-checkpoint', action='store', type=str,
                        help='Convert a checkpoint from training to pickle-independent '
                             'predictor-compatible directory')
    # Previously hard-coded; the default keeps the old output location.
    parser.add_argument('--output-dir', action='store', type=str,
                        default='tmp_checkp_path_deleteme',
                        help='Destination passed to save_checkpoint_dk when using '
                             '--convert-checkpoint (default: %(default)s)')
    return parser


def _convert_checkpoint(checkpoint_path: str, output_dir: str) -> None:
    """Load a training checkpoint on CPU and re-save it via ``save_checkpoint_dk``.

    Args:
        checkpoint_path: Path of the Lightning checkpoint produced by training.
        output_dir: Destination handed to ``ViTLightningModule.save_checkpoint_dk``.
    """
    print("Converting checkpoint", checkpoint_path)
    # Diagnostic: list the checkpoint's top-level keys. The loaded object is a
    # temporary, so it is not kept alive during the second load below.
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    print(list(torch.load(checkpoint_path, map_location="cpu").keys()))
    model = ViTLightningModule.load_from_checkpoint(
        checkpoint_path,
        map_location="cpu",
        # NOTE(review): hard-coded, apparently throwaway hparams file name;
        # confirm whether it is still required.
        hparams_file="tmp_ckpt_deleteme.yaml")
    model.save_checkpoint_dk(output_dir)
    print("Saved checkpoint. Done.")


def _train(tag, debug: bool) -> None:
    """Train a fresh ``ViTLightningModule`` with TensorBoard logging.

    Args:
        tag: Optional suffix (``str`` or ``None``) appended to the run's
            timestamped version directory.
        debug: When True, enables Lightning's ``fast_dev_run`` (also passed to
            the model constructor).
    """
    print("Start training")
    fast_dev_run = debug  # store_true flags are already bool
    model = ViTLightningModule(fast_dev_run)

    # Version dir: ./lightning_logs/<timestamp>[_<tag>]
    run_name = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if tag is not None:
        run_name += f"_{tag}"
    logger = TensorBoardLogger(save_dir=".", name="lightning_logs", version=run_name)

    trainer = Trainer(
        logger=logger,
        benchmark=True,  # lets cuDNN pick the fastest conv algorithms
        devices="auto",
        accelerator="auto",
        max_epochs=-1,  # no epoch limit: runs until stopped externally
        callbacks=[
            ModelSummary(max_depth=-1),  # summarize all nested modules
        ],
        fast_dev_run=fast_dev_run,
        log_every_n_steps=10,
    )
    # NOTE(review): relies on the module's private dataloader attributes.
    trainer.fit(
        model,
        train_dataloaders=model._train_dataloader,
        val_dataloaders=model._val_dataloader,
    )
    print("Training done")


def main():
    """ Neural network trainer entry point. """
    args = build_parser().parse_args()
    # 'high' lets float32 matmuls use faster reduced-precision kernels
    # (e.g. TF32 on Ampere GPUs such as A100) where available.
    torch.set_float32_matmul_precision('high')
    if args.convert_checkpoint is not None:
        _convert_checkpoint(args.convert_checkpoint, args.output_dir)
    else:
        _train(args.tag, args.debug)


if __name__ == "__main__":
    main()