```python
import torch
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from src.trainer import LitTrainer


def argmax(a):
    # Index of the largest element in a sequence (a pure-Python
    # stand-in for torch.argmax that also works on plain lists).
    return max(range(len(a)), key=lambda x: a[x])


def get_dataloaders(dataset, test_data):
    # Hold out 20% of the training data for validation.
    train_size = round(len(dataset) * 0.8)
    validate_size = len(dataset) - train_size
    train_data, validate_data = random_split(dataset, [train_size, validate_size])
    # num_workers=8 assumes a machine with 8 CPU cores.
    return (DataLoader(train_data, num_workers=8),
            DataLoader(validate_data, num_workers=8),
            DataLoader(test_data, num_workers=8))
```
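As a quick sanity check of the 80/20 split, the helper can be exercised with any map-style dataset. The torchvision MNIST dataset below is only an illustrative stand-in for whatever dataset the project actually trains on:

```python
from torchvision import datasets, transforms

# Hypothetical usage - MNIST is a placeholder dataset.
mnist_train = datasets.MNIST("data", train=True, download=True,
                             transform=transforms.ToTensor())
mnist_test = datasets.MNIST("data", train=False, download=True,
                            transform=transforms.ToTensor())

train_loader, validate_loader, test_loader = get_dataloaders(mnist_train, mnist_test)
print(len(train_loader.dataset), len(validate_loader.dataset))  # 48000 12000
```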
```python
def train_loop(net, batch, loss_fn, optim, device="cuda"):
    x, y = batch
    x = x.to(device)
    y = y.to(device)
    # The loaders above use the default batch_size=1, so the logits are
    # reshaped into a single row of class scores.
    y_pred = net(x).reshape(1, -1)
    loss = loss_fn(y_pred, y)
    # 1 if the highest-scoring class matches the label, else 0.
    truth_count = int(argmax(y_pred.flatten()) == y.item())
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item(), truth_count


def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
    for i in range(epochs):
        print(f"Epoch: {i}")
        epoch_loss = 0
        epoch_truth_count = 0
        for idx, batch in enumerate(train_loader):
            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
            epoch_loss += loss
            epoch_truth_count += truth_count
            if idx % 1000 == 0:
                print(f"Loss: {loss} ({idx} / {len(train_loader)} x {i})")
        print(f"Epoch Loss: {epoch_loss}")
        # With batch_size=1, len(train_loader) equals the number of samples.
        print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
    torch.save(net.state_dict(), "checkpoints/pytorch/version_1.pt")
```
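Wiring the manual loop together might look like the sketch below. The two-layer network, the SGD hyperparameters, and CrossEntropyLoss are assumptions made for illustration, not choices taken from the original project:

```python
# Hypothetical end-to-end run using the MNIST loaders built above.
device = "cuda" if torch.cuda.is_available() else "cpu"

net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
).to(device)

optim = torch.optim.SGD(net.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

train_net_manually(net, optim, loss_fn, train_loader,
                   validate_loader=validate_loader, epochs=2, device=device)
```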
```python
def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
    # Log to a local MLflow file store instead of Lightning's default TensorBoard logger.
    logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
    pl_net = LitTrainer(net)
    pl_net.optim = optim
    pl_net.loss = loss_fn
    trainer = pl.Trainer(limit_train_batches=100, max_epochs=epochs,
                         default_root_dir="../checkpoints", logger=logger)
    trainer.fit(pl_net, train_loader, validate_loader)
```
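src/trainer.py is not shown in this section. Judging by how train_net_lightning assigns optim and loss before calling trainer.fit, LitTrainer is presumably a LightningModule along the lines of the sketch below; the step logic and logging calls are inferred, not the project's actual code:

```python
class LitTrainer(pl.LightningModule):
    # Inferred sketch: wraps an arbitrary nn.Module and expects .optim
    # and .loss to be assigned before trainer.fit() is called.
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.optim = None
        self.loss = None

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.net(x).reshape(1, -1)
        loss = self.loss(y_pred, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", self.loss(self.net(x).reshape(1, -1), y))

    def configure_optimizers(self):
        return self.optim
```

With something like that in place, the Lightning path mirrors the manual one: `train_net_lightning(net, optim, loss_fn, train_loader, validate_loader, epochs=2)`, with checkpoints and MLflow metrics handled by the Trainer rather than by hand.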