Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| from torch import nn, optim | |
| import pytorch_lightning as pl | |
class LitTrainer(pl.LightningModule):
    """Lightning module that fits ``model`` with Adam and cross-entropy loss.

    Each batch is unpacked as ``(x, y)``; ``model(x)`` is reshaped to a single
    row and scored against ``y`` with ``nn.CrossEntropyLoss``.

    Args:
        model: The wrapped network, invoked as ``model(x)``.
        lr: Learning rate handed to ``optim.Adam``. Defaults to ``1e-4``.
    """

    def __init__(self, model, *, lr: float = 1e-4):
        super().__init__()
        self.model = model
        # Created after ``self.model`` is assigned so that ``self.parameters()``
        # can see the wrapped model.
        # NOTE(review): Lightning's documentation builds the optimizer inside
        # ``configure_optimizers``; it is kept here so the ``optim`` attribute
        # remains available right after construction - confirm no strategy in
        # use re-creates or re-wraps the parameters before training.
        self.optim = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def _compute_and_log_loss(self, batch, log_name: str):
        """Run the model on one batch, log the loss as ``log_name``, return it."""
        x, y = batch
        # NOTE(review): ``reshape(1, -1)`` flattens the whole model output into
        # one row, so this presumably expects a single sample per batch with
        # ``y`` holding one target - confirm against the dataloader.
        y_pred = self.model(x).reshape(1, -1)
        loss = self.loss(y_pred, y)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the loss for one training batch, logged as ``train_loss``."""
        return self._compute_and_log_loss(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Log the loss for one validation batch as ``val_loss``."""
        self._compute_and_log_loss(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Log the loss for one test batch as ``test_loss``."""
        self._compute_and_log_loss(batch, "test_loss")

    def forward(self, x):
        """Return the wrapped model's output for ``x`` (no reshape applied)."""
        return self.model(x)

    def configure_optimizers(self):
        """Return the Adam optimizer created in ``__init__``."""
        return self.optim