Janne Hellsten committed

Commit f7e4867 · Parent(s): d3a616a
Add --allow-tf32 perf tuning argument that can be used to enable tf32

Defaults to keeping tf32 disabled. This is because we haven't fully
verified training results with tf32 enabled.
- docs/train-help.txt +1 -0
- train.py +8 -0
- training/training_loop.py +3 -0
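Usage note (a hedged sketch, not part of this commit): with this change, TF32 would be enabled by appending --allow-tf32=True to an existing train.py invocation, e.g. python train.py --outdir=training-runs --data=datasets/example.zip --gpus=1 --allow-tf32=True, where --outdir, --data and --gpus are assumed from the existing CLI and the paths are placeholders.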
    	
docs/train-help.txt  CHANGED
@@ -65,5 +65,6 @@ Options:
   --fp32 BOOL                     Disable mixed-precision training
   --nhwc BOOL                     Use NHWC memory format with FP16
   --nobench BOOL                  Disable cuDNN benchmarking
+  --allow-tf32 BOOL               Allow PyTorch to use TF32 internally
   --workers INT                   Override number of DataLoader workers
   --help                          Show this message and exit.
    	
train.py  CHANGED
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
     # Performance options (not included in desc).
     fp32       = None, # Disable mixed-precision training: <bool>, default = False
     nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
+    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
     nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
     workers    = None, # Override number of DataLoader workers: <int>, default = 3
 ):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
     if nobench:
         args.cudnn_benchmark = False
 
+    if allow_tf32 is None:
+        allow_tf32 = False
+    assert isinstance(allow_tf32, bool)
+    if allow_tf32:
+        args.allow_tf32 = True
+
     if workers is not None:
         assert isinstance(workers, int)
         if not workers >= 1:
@@ -425,6 +432,7 @@ class CommaSeparatedList(click.ParamType):
 @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
 @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
 @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
 @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
 
 def main(ctx, outdir, dry_run, **config_kwargs):
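To illustrate how the new @click.option('--allow-tf32', type=bool, ...) parses its value, here is a minimal, self-contained sketch. It is a hypothetical demo script, not this repo's train.py; in particular default=False is assumed here only for the demo, while the real option leaves the value as None and setup_training_loop_kwargs() handles the default.

import click

@click.command()
# Same shape as the option added above; default=False is an assumption for this demo
# so that it prints something when the flag is omitted.
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL', default=False)
def demo(allow_tf32):
    # click's BOOL type accepts values such as true/false, 1/0 or yes/no (case-insensitive).
    click.echo(f'allow_tf32 = {allow_tf32}')

if __name__ == '__main__':
    demo()  # e.g. python demo.py --allow-tf32=True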
    	
training/training_loop.py  CHANGED
@@ -115,6 +115,7 @@ def training_loop(
     network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
     resume_pkl              = None,     # Network pickle to resume training from.
     cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
+    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
     abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
     progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
 ):
@@ -124,6 +125,8 @@ def training_loop(
     np.random.seed(random_seed * num_gpus + rank)
     torch.manual_seed(random_seed * num_gpus + rank)
     torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
+    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
+    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
     conv2d_gradfix.enabled = True                       # Improves training speed.
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
 
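For context on what the two backend switches above control, the following is a minimal sketch (not part of the commit). Assumptions: a CUDA-capable machine; the numeric gap only shows up on Ampere-or-newer GPUs where TF32 matmuls are actually used, elsewhere the two results simply match.

import torch

def matmul_max_error(allow_tf32: bool) -> float:
    # Toggle the same flag the training loop sets, then compare a single matmul
    # against an fp64 reference (fp64 is unaffected by the TF32 setting).
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.manual_seed(0)
    a = torch.randn(1024, 1024, device='cuda')
    b = torch.randn(1024, 1024, device='cuda')
    ref = a.double() @ b.double()   # fp64 reference result
    out = a @ b                     # fp32 or TF32, depending on the flag
    return (out.double() - ref).abs().max().item()

if torch.cuda.is_available():
    print('allow_tf32=False ->', matmul_max_error(False))
    print('allow_tf32=True  ->', matmul_max_error(True))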
