Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| # KAGGLE SPECIFIC: This script is used to make training compatible with Kaggle's notebook environment. | |
| import torch | |
| import os | |
def load_and_save_checkpoint(
    input_filename,
    output_filename,
    device,
    keys=(
        "epoch",
        "generator_state_dict",
        "discriminator_state_dict",
        "optimizerG_state_dict",
        "optimizerD_state_dict",
    ),
):
    """Copy a training checkpoint, keeping only the entries listed in ``keys``.

    Loads ``input_filename`` with ``torch.load`` (``device`` is used as
    ``map_location``), builds a new dict containing only ``keys`` and writes
    it to ``output_filename`` with ``torch.save``. Every other entry of the
    source checkpoint is dropped.

    Args:
        input_filename: Path of the checkpoint to read.
        output_filename: Path to write the reduced checkpoint to. Missing
            parent directories are created.
        device: Passed as ``map_location`` to ``torch.load``; tensors are
            loaded onto this device and saved from there.
        keys: Checkpoint entries to keep. Defaults to the epoch counter plus
            the generator/discriminator model and optimizer state dicts.

    Returns:
        True if the checkpoint was converted and saved; False if
        ``input_filename`` is not an existing file (a message is printed).

    Raises:
        KeyError: If the loaded checkpoint lacks any of ``keys``.
    """
    if not os.path.isfile(input_filename):
        print(f"No checkpoint found at '{input_filename}'")
        return False

    print(f"Loading checkpoint '{input_filename}'")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(input_filename, map_location=device)

    # Report every missing entry at once instead of failing on the first one.
    missing = [key for key in keys if key not in checkpoint]
    if missing:
        raise KeyError(
            f"Checkpoint '{input_filename}' is missing required entries: {missing}"
        )

    # Extract only the necessary state
    save_state = {key: checkpoint[key] for key in keys}

    # torch.save does not create directories, so ensure the target folder exists.
    output_dir = os.path.dirname(output_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    torch.save(save_state, output_filename)
    print(f"Saved checkpoint to '{output_filename}'")
    return True
if __name__ == "__main__":
    import argparse

    # Paths can be overridden on the command line; running with no
    # arguments uses the defaults below.
    parser = argparse.ArgumentParser(
        description="Strip a training checkpoint down to the state needed to resume training."
    )
    parser.add_argument(
        "--input",
        default="checkpoints/latest_checkpoint.pth.tar",
        help="checkpoint to read (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default="checkpoints/converted_checkpoint.pth.tar",
        help="where to write the reduced checkpoint (default: %(default)s)",
    )
    # parse_known_args ignores extra arguments a notebook kernel may put in
    # sys.argv (e.g. ipykernel's "-f <connection file>") instead of exiting.
    args, _unknown = parser.parse_known_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    input_checkpoint = args.input
    output_checkpoint = args.output
    load_and_save_checkpoint(input_checkpoint, output_checkpoint, device)