Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import torch | |
| import mlflow | |
| from data_ingestion import create_dataloaders, test_data_ingestion | |
| from model import Generator, Discriminator, init_weights, test_models | |
| from train import train, test_training | |
| from app import setup_gradio_app | |
EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Return the MLflow experiment id for EXPERIMENT_NAME.

    Looks the experiment up by name and creates it on first use, so the
    rest of the pipeline can assume it exists.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    # First run on this tracking store: register the experiment.
    return mlflow.create_experiment(EXPERIMENT_NAME)
def _ensure_dataloader(train_loader, batch_size, purpose):
    """Return *train_loader* if present, otherwise create one.

    Prints a failure message and returns None when creation fails;
    *purpose* ("training"/"testing") only affects the log messages.
    """
    if train_loader is not None:
        return train_loader
    print(f"Creating dataloader for {purpose}...")
    train_loader = create_dataloaders(batch_size=batch_size)
    if train_loader is None:
        print(f"Failed to create dataloader for {purpose}.")
    return train_loader


def _ensure_models(generator, discriminator, device, purpose):
    """Return a (generator, discriminator) pair, building and weight-initializing
    fresh models on *device* if either is missing."""
    if generator is None or discriminator is None:
        print(f"Creating models for {purpose}...")
        generator = Generator().to(device)
        discriminator = Discriminator().to(device)
        generator.apply(init_weights)
        discriminator.apply(init_weights)
    return generator, discriminator


def run_pipeline(args):
    """Run the selected stages of the colorizer pipeline.

    Stages are toggled by boolean attributes on *args* (ingest_data,
    create_model, train, test_training, serve, run_all) and run in order:
    data ingestion -> model creation/self-test -> training -> training
    smoke-test -> Gradio serving. Any stage failure prints a message and
    aborts the remaining stages. The MLflow run id of a successful
    training run is persisted to latest_run_id.txt so a later serve-only
    invocation can pick it up.
    """
    device = torch.device(args.device)
    print(f"Using device: {device}")
    # Ensure the MLflow experiment exists before any stage logs to it.
    setup_mlflow()

    train_loader = None
    if args.ingest_data or args.run_all:
        print("Starting data ingestion...")
        train_loader = create_dataloaders(batch_size=args.batch_size)
        if train_loader is None:
            print("Data ingestion failed.")
            return

    generator = None
    discriminator = None
    if args.create_model or args.train or args.run_all:
        print("Creating and testing models...")
        generator = Generator().to(device)
        discriminator = Discriminator().to(device)
        generator.apply(init_weights)
        discriminator.apply(init_weights)
        if not test_models():
            print("Model creation or testing failed.")
            return

    if args.train or args.run_all:
        print("Starting model training...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "training")
        if train_loader is None:
            return
        generator, discriminator = _ensure_models(
            generator, discriminator, device, "training"
        )
        run_id = train(
            generator,
            discriminator,
            train_loader,
            num_epochs=args.num_epochs,
            device=device,
        )
        if run_id:
            print(f"Training completed. Run ID: {run_id}")
            # Persist the run id so --serve can find the model later.
            with open("latest_run_id.txt", "w") as f:
                f.write(run_id)
        else:
            print("Training failed.")
            return

    if args.test_training:
        print("Testing training process...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "testing")
        if train_loader is None:
            return
        generator, discriminator = _ensure_models(
            generator, discriminator, device, "testing"
        )
        if test_training(generator, discriminator, train_loader, device):
            print("Training process test passed.")
        else:
            print("Training process test failed.")

    if args.serve or args.run_all:
        print("Setting up Gradio app for serving...")
        if not args.run_id:
            # No explicit run id: fall back to the one saved by training.
            try:
                with open("latest_run_id.txt", "r") as f:
                    args.run_id = f.read().strip()
            except FileNotFoundError:
                print("No run ID provided and couldn't find latest_run_id.txt")
                return
        iface = setup_gradio_app(args.run_id, device)
        iface.launch(share=args.share)
if __name__ == "__main__":
    # CLI entry point: each boolean flag enables one pipeline stage;
    # --run_all turns them all on at once.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Run Colorizer Pipeline")

    # Runtime configuration.
    parser.add_argument("--device", type=str, default=default_device,
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train")
    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model for inference")

    # Stage selection flags.
    parser.add_argument("--ingest_data", action="store_true", help="Run data ingestion")
    parser.add_argument("--create_model", action="store_true", help="Create and test the model")
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--test_training", action="store_true", help="Test the training process")
    parser.add_argument("--serve", action="store_true", help="Serve the model using Gradio")
    parser.add_argument("--run_all", action="store_true", help="Run all steps")
    parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")

    run_pipeline(parser.parse_args())