File size: 742 Bytes
deb7039 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
def train(model, trainloader, optimizer, criterion, DEVICE):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: The module being trained; it is switched to train mode.
        trainloader: Iterable of batches where ``batch[0]`` is the model input
            and ``batch[1]`` is the target (e.g. a ``DataLoader``).
        optimizer: Optimizer to step after each batch (presumably built over
            ``model``'s parameters — confirm at call site).
        criterion: Loss function called as ``criterion(output_logits, target)``.
        DEVICE: Device argument forwarded to ``Tensor.to`` (e.g. ``"cpu"``).

    Returns:
        The average of ``loss.item()`` over every batch yielded by
        ``trainloader``, or ``0.0`` if it yielded no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data in trainloader:
        optimizer.zero_grad()
        imgs, target = data[0].to(DEVICE), data[1].to(DEVICE)
        output_logits = model(imgs)
        loss = criterion(output_logits, target)
        # Accumulate (not overwrite) so the epoch average covers every batch.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        num_batches += 1
    # Divide by the batches actually seen: avoids ZeroDivisionError on an
    # empty loader and also works for iterables that do not define __len__.
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    print("epoch loss = {}".format(epoch_loss))
    return epoch_loss