Spaces:
Sleeping
Sleeping
| import os | |
| import mlflow | |
| import torch | |
| from torch.utils.data import IterableDataset, DataLoader | |
| from torchvision import transforms | |
| from datasets import load_dataset | |
| from skimage.color import rgb2lab | |
| from PIL import Image | |
| import numpy as np | |
class ColorizeIterableDataset(IterableDataset):
    """Stream (L, ab) LAB-space tensor pairs for colorization training.

    Wraps an iterable of ``{'image': PIL.Image}`` records (e.g. a streaming
    Hugging Face dataset) and yields, per image:

    - L channel: float32 tensor of shape (1, 256, 256), scaled to [-1, 1]
    - ab channels: float32 tensor of shape (2, 256, 256), scaled to ~[-1, 1]
    """

    def __init__(self, dataset):
        # assumes each item of `dataset` is a dict with a PIL image under
        # 'image' — TODO confirm against the upstream dataset schema
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def __iter__(self):
        for item in self.dataset:
            try:
                img = item['image']
                # rgb2lab needs 3-channel input; convert grayscale/palette/CMYK.
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = self.transform(img)
                # rgb2lab expects an HWC float array in [0, 1] (float64 out).
                lab = rgb2lab(img.permute(1, 2, 0).numpy())
                # L is in [0, 100]; recenter and rescale to [-1, 1].
                l_chan = (lab[:, :, 0] - 50.0) / 50.0
                # a/b lie roughly in [-128, 127]; rescale to ~[-1, 1].
                ab_chan = lab[:, :, 1:] / 128.0
                # from_numpy + .float() converts float64 -> float32 without the
                # legacy torch.Tensor(ndarray) constructor; ascontiguousarray /
                # .contiguous() keep the permuted output dense for downstream ops.
                yield (
                    torch.from_numpy(np.ascontiguousarray(l_chan)).float().unsqueeze(0),
                    torch.from_numpy(np.ascontiguousarray(ab_chan)).float().permute(2, 0, 1).contiguous(),
                )
            except Exception as e:
                # Best-effort streaming: skip corrupt or undecodable images
                # instead of aborting the whole epoch.
                print(f"Error processing image: {str(e)}")
                continue
def create_dataloaders(batch_size=32, num_workers=0):
    """Build a streaming ImageNet DataLoader of (L, ab) LAB tensor pairs.

    Args:
        batch_size: number of samples per batch.
        num_workers: DataLoader worker count. Defaults to 0 (main-process
            loading) — see bug note below before raising it.

    Returns:
        A ``DataLoader`` over :class:`ColorizeIterableDataset`, or ``None``
        if dataset loading fails.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        # Load ImageNet dataset from Hugging Face in streaming mode
        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")
        print("Creating custom dataset...")
        # Create custom dataset
        train_dataset = ColorizeIterableDataset(dataset)
        print("Custom dataset created.")
        print("Creating dataloader...")
        # BUG FIX: with an IterableDataset and num_workers > 0, every worker
        # iterates the *entire* stream (no worker sharding is implemented), so
        # each sample would be duplicated num_workers times. Default to 0;
        # callers that add get_worker_info()-based sharding may raise it.
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                      num_workers=num_workers)
        print("Dataloader created.")
        return train_dataloader
    except Exception as e:
        print(f"Error in create_dataloaders: {str(e)}")
        return None
def test_data_ingestion():
    """Smoke-test the streaming pipeline: pull one batch and validate shapes."""
    print("Testing data ingestion...")
    try:
        dataloader = create_dataloaders(batch_size=4)
        if dataloader is None:
            raise Exception("Dataloader creation failed")
        # Only the first batch is needed; return from inside the loop.
        for sample_batch in dataloader:
            if len(sample_batch) != 2:
                raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2")
            l_chan, ab_chan = sample_batch
            want_l = torch.Size([4, 1, 256, 256])
            want_ab = torch.Size([4, 2, 256, 256])
            if l_chan.shape != want_l or ab_chan.shape != want_ab:
                raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}")
            print("Data ingestion test passed.")
            return True
    except Exception as e:
        print(f"Data ingestion test failed: {str(e)}")
        return False
if __name__ == "__main__":
    # Entry point: run the ingestion pipeline under an MLflow run, log one
    # sample batch as images, then run the smoke test.
    try:
        print("Starting data ingestion pipeline...")
        mlflow.start_run(run_name="data_ingestion")
        try:
            # Log parameters
            print("Logging parameters...")
            mlflow.log_param("batch_size", 32)
            mlflow.log_param("dataset", "imagenet-1k")
            print("Parameters logged.")
            # Create dataloaders
            print("Creating dataloaders...")
            train_dataloader = create_dataloaders(batch_size=32)
            if train_dataloader is None:
                raise Exception("Failed to create dataloader")
            print("Dataloaders created successfully.")
            # Log a sample batch
            print("Logging sample batch...")
            for sample_batch in train_dataloader:
                l_chan, ab_chan = sample_batch
                # BUG FIX: mlflow.log_image requires an HW or HWC array with
                # float values in [0, 1]; the tensors here are channel-first
                # and scaled to [-1, 1]. Squeeze the channel axis and rescale.
                sample_input = ((l_chan[0, 0].numpy() + 1.0) / 2.0).clip(0.0, 1.0)
                mlflow.log_image(sample_input, "sample_input_l_channel.png")
                # BUG FIX: a 2-channel (H, W, 2) array is not a valid image for
                # log_image; log the a and b channels as separate grayscale maps.
                sample_target = ((ab_chan[0].permute(1, 2, 0).numpy() + 1.0) / 2.0).clip(0.0, 1.0)
                mlflow.log_image(sample_target[:, :, 0], "sample_target_a_channel.png")
                mlflow.log_image(sample_target[:, :, 1], "sample_target_b_channel.png")
                print("Sample batch logged.")
                break  # We only need one batch for logging
            print("Data ingestion pipeline completed successfully.")
        except Exception as e:
            print(f"Error in data ingestion pipeline: {str(e)}")
            mlflow.log_param("error", str(e))
        finally:
            # Always close the MLflow run, even on failure.
            mlflow.end_run()
    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")

    if test_data_ingestion():
        print("All tests passed.")
    else:
        print("Tests failed.")