Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import IterableDataset, DataLoader | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| import mlflow | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import numpy as np | |
| from skimage.color import rgb2lab, lab2rgb | |
| from datasets import load_dataset | |
| from PIL import Image | |
| from itertools import islice | |
| import traceback | |
# MLflow setup
EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Return the id of the MLflow experiment named ``EXPERIMENT_NAME``.

    The experiment is looked up by name; when the lookup returns ``None`` a
    new experiment with that name is created and its id is returned instead.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
# Data ingestion
class ColorizeIterableDataset(IterableDataset):
    """Stream ``(L, ab)`` tensor pairs computed from the images of ``dataset``.

    Every item of the wrapped iterable is read through ``item['image']``,
    converted to RGB when its ``mode`` says otherwise, passed through
    ``transform`` (when one is set) and converted to Lab with ``rgb2lab``.
    Items that fail a shape check or raise while being processed are reported
    with ``print`` and skipped, so iteration never aborts on a bad item.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def _to_lab_pair(self, record):
        """Turn one record into ``(L, ab)`` tensors, or ``None`` if a shape check fails."""
        picture = record['image']
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        # Only channel-first 256x256 RGB tensors are accepted past this point.
        if picture.shape != (3, 256, 256):
            print(f"Unexpected image shape after transform: {picture.shape}")
            return None

        # rgb2lab is fed a channel-last array, hence the permute.
        lab = rgb2lab(picture.permute(1, 2, 0).numpy())
        if lab.shape != (256, 256, 3):
            print(f"Unexpected lab shape: {lab.shape}")
            return None

        # L is shifted and scaled by 50, ab is divided by 128; lab_to_rgb
        # applies the inverse mapping.
        lightness = (lab[:, :, 0] - 50) / 50
        chroma = lab[:, :, 1:] / 128
        return torch.Tensor(lightness).unsqueeze(0), torch.Tensor(chroma).permute(2, 0, 1)

    def __iter__(self):
        for record in self.dataset:
            try:
                pair = self._to_lab_pair(record)
                if pair is not None:
                    yield pair
            except Exception as err:
                # Best effort: report the bad item and move on to the next one.
                print(f"Error processing image: {err}")
def create_dataloaders(batch_size=32):
    """Build the training ``DataLoader`` over the streamed ``imagenet-1k`` train split.

    Images are resized to 256x256 and converted to tensors before being
    wrapped in ``ColorizeIterableDataset``.

    Returns:
        The ``DataLoader`` (4 workers, ``batch_size`` items per batch), or
        ``None`` after printing the error when any step raises.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        stream = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")

        print("Creating custom dataset...")
        preprocessing = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize all images to 256x256
            transforms.ToTensor(),
        ])
        colorize_dataset = ColorizeIterableDataset(stream, transform=preprocessing)
        print("Custom dataset created.")

        print("Creating dataloader...")
        loader = DataLoader(colorize_dataset, batch_size=batch_size, num_workers=4)
        print("Dataloader created.")
    except Exception as err:
        print(f"Error in create_dataloaders: {err}")
        return None
    return loader
# Model definition
class UNetBlock(nn.Module):
    """One stride-2 U-Net stage: conv -> optional BatchNorm -> optional Dropout -> ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        down: When True use ``nn.Conv2d`` (kernel 4, stride 2, padding 1);
            when False use ``nn.ConvTranspose2d`` with the same hyper-parameters.
        bn: When True apply ``nn.BatchNorm2d(out_channels)`` after the convolution.
        dropout: When True apply ``nn.Dropout(0.5)`` before the activation.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        # Explicit None checks: the optional layers are either a module or None.
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # Functional ReLU avoids building a throw-away nn.ReLU module on every
        # forward pass. As before, only the up path runs the activation in place.
        return nn.functional.relu(x, inplace=not self.down)
class Generator(nn.Module):
    """U-Net generator mapping a 1-channel input to a 2-channel, tanh-bounded output.

    Eight ``UNetBlock`` down stages feed seven ``UNetBlock`` up stages and a
    final ``nn.ConvTranspose2d``; every up stage after the first receives the
    previous up output concatenated (channel axis) with the mirrored down
    output. Sub-modules keep the attribute names ``down1..down8`` and
    ``up1..up8`` so saved state dicts stay loadable.
    """

    # (in_channels, out_channels, use_batchnorm) for down1..down8.
    _DOWN_SPEC = (
        (1, 64, False),
        (64, 128, True),
        (128, 256, True),
        (256, 512, True),
        (512, 512, True),
        (512, 512, True),
        (512, 512, True),
        (512, 512, False),
    )
    # (in_channels, out_channels, use_dropout) for up1..up7; in_channels
    # already accounts for the concatenated skip connection.
    _UP_SPEC = (
        (512, 512, True),
        (1024, 512, True),
        (1024, 512, True),
        (1024, 512, False),
        (1024, 256, False),
        (512, 128, False),
        (256, 64, False),
    )

    def __init__(self):
        super().__init__()
        for idx, (c_in, c_out, use_bn) in enumerate(self._DOWN_SPEC, start=1):
            setattr(self, f"down{idx}", UNetBlock(c_in, c_out, bn=use_bn))
        for idx, (c_in, c_out, use_dropout) in enumerate(self._UP_SPEC, start=1):
            setattr(self, f"up{idx}", UNetBlock(c_in, c_out, down=False, dropout=use_dropout))
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        # Encoder: remember every stage output for the skip connections.
        skips = []
        hidden = x
        for idx in range(1, 9):
            hidden = getattr(self, f"down{idx}")(hidden)
            skips.append(hidden)

        # The deepest output feeds up1 directly and is not used as a skip.
        skips.pop()
        hidden = self.up1(hidden)

        # Decoder: pair each up stage with the mirrored encoder output (d7..d2).
        for idx in range(2, 8):
            hidden = getattr(self, f"up{idx}")(torch.cat([hidden, skips.pop()], 1))

        # The last skip left is d1.
        return torch.tanh(self.up8(torch.cat([hidden, skips.pop()], 1)))
class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of logits for a 3-channel input.

    The network is a single ``nn.Sequential`` stored as ``self.model``: a
    Conv/LeakyReLU stem, three Conv/BatchNorm/LeakyReLU stages (the last one
    with stride 1) and a final 1-channel convolution. No sigmoid is applied.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (in_channels, out_channels, stride) of the normalised stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def init_weights(model):
    """Re-initialise ``model`` in place based on its class name.

    Intended for ``nn.Module.apply``. Classes whose name contains ``Conv`` get
    weights drawn from N(0, 0.02); classes whose name contains ``BatchNorm``
    get weights drawn from N(1, 0.02) and a zero bias. Any other module is
    left untouched.
    """
    kind = type(model).__name__
    if 'Conv' in kind:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)
# Training utilities
def lab_to_rgb(L, ab):
    """Convert batched, normalised L and ab tensors to a stacked RGB numpy array.

    ``L`` is mapped back with ``(L + 1) * 50`` and ``ab`` with ``ab * 128``
    before both are concatenated on dim 1, moved to channel-last layout and
    converted image by image with ``lab2rgb``.

    Returns:
        ``np.ndarray`` stacking the per-image ``lab2rgb`` results on axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 128.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_image) for lab_image in lab_batch], axis=0)
def visualize_results(epoch, generator, train_loader, device):
    """Save and log a real-vs-generated comparison image for one batch.

    The first batch yielded by ``train_loader`` is colourised with
    ``generator`` (in eval mode, without gradients). Ground-truth and
    generated RGB images are placed side by side, arranged in a grid, written
    to ``results/epoch_{epoch}.png`` and logged as an MLflow artifact.

    Any exception is printed together with its traceback and swallowed, so a
    failed visualisation never interrupts training. The matplotlib figure is
    always closed and the generator is always put back into the mode
    (train/eval) it was in when this function was called.

    Args:
        epoch: Value used in the figure title and the output file name.
        generator: Module called as ``generator(inputs)``.
        train_loader: Iterable yielding ``(inputs, real_AB)`` batches.
        device: Device the batch is moved to before running the generator.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for inputs, real_AB in train_loader:
                print(f"Input shape: {inputs.shape}, real_AB shape: {real_AB.shape}")

                # Ensure inputs have the correct shape (B, 1, H, W)
                if inputs.shape[1] != 1:
                    inputs = inputs.unsqueeze(1)

                inputs, real_AB = inputs.to(device), real_AB.to(device)
                fake_AB = generator(inputs)
                print(f"fake_AB shape: {fake_AB.shape}")

                # Ensure fake_AB and real_AB have the correct shape (B, 2, H, W)
                if fake_AB.shape[1] != 2:
                    fake_AB = fake_AB.view(fake_AB.shape[0], 2, fake_AB.shape[2], fake_AB.shape[3])
                if real_AB.shape[1] != 2:
                    real_AB = real_AB.view(real_AB.shape[0], 2, real_AB.shape[2], real_AB.shape[3])

                fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
                real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())
                print(f"fake_rgb shape: {fake_rgb.shape}, real_rgb shape: {real_rgb.shape}")

                # The RGB batches are channel-last, so axis 2 is the width:
                # each real image ends up next to its generated counterpart.
                concatenated = np.concatenate([real_rgb, fake_rgb], axis=2)
                print(f"Concatenated shape: {concatenated.shape}")

                img_grid = make_grid(torch.from_numpy(concatenated).permute(0, 3, 1, 2), normalize=True, nrow=4)

                output_path = f'results/epoch_{epoch}.png'
                fig = plt.figure(figsize=(15, 15))
                try:
                    plt.imshow(img_grid.permute(1, 2, 0).cpu())
                    plt.axis('off')
                    plt.title(f'Epoch {epoch}')
                    plt.savefig(output_path)
                    mlflow.log_artifact(output_path)
                finally:
                    # Close the figure even when saving/logging fails so
                    # figures do not pile up across epochs.
                    plt.close(fig)

                # Only the first batch is visualised.
                break
    except Exception as e:
        print(f"Error in visualize_results: {str(e)}")
        traceback.print_exc()
    finally:
        # Restore whichever mode the caller had instead of forcing train mode.
        generator.train(was_training)
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Write the resumable subset of ``state`` to ``filename`` and log it to MLflow.

    Only the epoch and the four state dicts (generator, discriminator and both
    optimizers) are persisted; any other entry of ``state`` is dropped. A
    missing required key raises ``KeyError``.
    """
    # Only save the necessary state
    kept_keys = (
        'epoch',
        'generator_state_dict',
        'discriminator_state_dict',
        'optimizerG_state_dict',
        'optimizerD_state_dict',
    )
    torch.save({key: state[key] for key in kept_keys}, filename)
    mlflow.log_artifact(filename)
def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD, device):
    """Restore models and optimizers from ``filename`` and return the epoch to resume at.

    Returns:
        ``checkpoint['epoch'] + 1`` when the file exists, otherwise ``0``
        (nothing is loaded and the models/optimizers are left untouched).
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return 0

    print(f"Loading checkpoint '{filename}'")
    # Use weights_only=True for safer loading
    checkpoint = torch.load(filename, map_location=device, weights_only=True)
    saved_epoch = checkpoint['epoch']
    next_epoch = saved_epoch + 1

    restore_plan = (
        (generator, 'generator_state_dict'),
        (discriminator, 'discriminator_state_dict'),
        (optimizerG, 'optimizerG_state_dict'),
        (optimizerD, 'optimizerD_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

    print(f"Loaded checkpoint '{filename}' (epoch {saved_epoch})")
    return next_epoch
# Global variables
checkpoint_dir = "checkpoints"

# Both output directories are created at import time: checkpoints are written
# under checkpoint_dir, visualisation images under "results".
for _output_dir in (checkpoint_dir, "results"):
    os.makedirs(_output_dir, exist_ok=True)
# Training function
def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5,
          num_iterations=2000, l1_weight=100):
    """Adversarially train the generator/discriminator pair and track the run in MLflow.

    Training resumes from ``<checkpoint_dir>/latest_checkpoint.pth.tar`` when
    that file exists. Each epoch consumes at most ``num_iterations`` batches,
    alternating one discriminator step (BCE-with-logits on real vs. generated
    L+ab images) and one generator step (adversarial loss plus weighted L1 to
    the real ab channels). After every epoch a comparison image is produced
    and a checkpoint is saved. When all epochs finish, the generator is logged
    to MLflow and registered as ``colorizer_generator``.

    Args:
        generator: Module mapping an L batch to an ab batch.
        discriminator: Module scoring a concatenated L+ab batch.
        train_loader: Iterable yielding ``(real_L, real_AB)`` batches; batches
            whose trailing shapes are not ``(1, 256, 256)`` / ``(2, 256, 256)``
            are reported and skipped.
        num_epochs: Exclusive upper bound of the epoch range.
        device: Device batches are moved to.
        lr: Learning rate of both Adam optimizers.
        beta1: First Adam beta of both optimizers (the second is 0.999).
        num_iterations: Maximum number of batches consumed per epoch.
        l1_weight: Multiplier applied to the generator's L1 loss.

    Returns:
        The MLflow run id on success, or ``None`` when an exception was raised
        inside the training loop (the error is printed with its traceback and
        logged as the MLflow param ``error``).
    """
    criterion = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD, device)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
        try:
            for epoch in range(start_epoch, num_epochs):
                generator.train()
                discriminator.train()

                # NOTE(review): islice() starts a fresh iterator over
                # train_loader every epoch. With a streaming, unshuffled
                # dataset this presumably replays the same leading batches
                # each epoch - confirm that this is intended.
                pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}")

                for i, (real_L, real_AB) in pbar:
                    # Add shape check
                    if real_L.shape[1:] != (1, 256, 256) or real_AB.shape[1:] != (2, 256, 256):
                        print(f"Unexpected tensor shapes: real_L {real_L.shape}, real_AB {real_AB.shape}")
                        continue

                    real_L, real_AB = real_L.to(device), real_AB.to(device)

                    # Train Discriminator
                    optimizerD.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    real_LAB = torch.cat([real_L, real_AB], dim=1)

                    # detach() keeps the discriminator step from back-propagating
                    # into the generator.
                    pred_fake = discriminator(fake_LAB.detach())
                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

                    pred_real = discriminator(real_LAB)
                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))

                    loss_D = (loss_D_fake + loss_D_real) * 0.5
                    loss_D.backward()
                    optimizerD.step()

                    # Train Generator
                    optimizerG.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    pred_fake = discriminator(fake_LAB)

                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
                    loss_G_L1 = l1_loss(fake_AB, real_AB) * l1_weight  # L1 loss weight

                    loss_G = loss_G_GAN + loss_G_L1
                    loss_G.backward()
                    optimizerG.step()

                    # Read each scalar once; every .item() call forces a
                    # device-to-host transfer.
                    d_loss_value = loss_D.item()
                    g_loss_value = loss_G.item()
                    g_l1_value = loss_G_L1.item()

                    pbar.set_postfix({
                        'D_loss': d_loss_value,
                        'G_loss': g_loss_value,
                        'G_L1': g_l1_value
                    })

                    mlflow.log_metrics({
                        "D_loss": d_loss_value,
                        "G_loss": g_loss_value,
                        "G_L1_loss": g_l1_value
                    }, step=epoch * num_iterations + i)

                visualize_results(epoch, generator, train_loader, device)

                checkpoint = {
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                }
                save_checkpoint(checkpoint, filename=checkpoint_path)

            print("Training completed successfully.")
            mlflow.pytorch.log_model(generator, "generator_model")
            model_uri = f"runs:/{run.info.run_id}/generator_model"
            mlflow.register_model(model_uri, "colorizer_generator")
            return run.info.run_id

        except Exception as e:
            print(f"Error during training: {str(e)}")
            # Same reporting style as visualize_results: keep the stack trace
            # visible instead of only the message.
            traceback.print_exc()
            mlflow.log_param("error", str(e))
            return None
# Main execution
if __name__ == "__main__":
    # Script entry point: build the data loader and both networks, train, and
    # record the resulting MLflow run id in latest_run_id.txt.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        batch_size = 32
        num_epochs = 50

        train_loader = create_dataloaders(batch_size=batch_size)
        if train_loader is None:
            # create_dataloaders signals failure by returning None.
            raise RuntimeError("Failed to create dataloader")

        generator = Generator().to(device)
        discriminator = Discriminator().to(device)

        generator.apply(init_weights)
        discriminator.apply(init_weights)

        run_id = train(generator, discriminator, train_loader, num_epochs=num_epochs, device=device)

        if run_id:
            print(f"Training completed successfully. Run ID: {run_id}")
            with open("latest_run_id.txt", "w") as f:
                f.write(run_id)
        else:
            print("Training failed!")
    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")
        # Keep the stack trace; the message alone rarely locates the failure.
        traceback.print_exc()