Spaces:
Runtime error
Runtime error
fixed issue with test acc on line 249
Browse files
app.py
CHANGED
|
@@ -239,15 +239,14 @@ def train_and_test(train_model=True):
|
|
| 239 |
# Train for one epoch and test
|
| 240 |
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
|
| 241 |
|
| 242 |
-
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
|
| 243 |
-
)
|
| 244 |
train(n_epochs,network,optimizer,train_loader)
|
| 245 |
|
| 246 |
test_metric,test_acc = test()
|
| 247 |
|
| 248 |
if os.path.exists(METRIC_PATH):
|
| 249 |
metric_dict = read_json(METRIC_PATH)
|
| 250 |
-
metric_dict['all'] = metric_dict['all'] if 'all' in metric_dict else [] + [test_acc]
|
| 251 |
else:
|
| 252 |
metric_dict={}
|
| 253 |
metric_dict['all'] = [test_acc]
|
|
|
|
| 239 |
# Train for one epoch and test
|
| 240 |
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
|
| 241 |
|
| 242 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True)
|
|
|
|
| 243 |
train(n_epochs,network,optimizer,train_loader)
|
| 244 |
|
| 245 |
test_metric,test_acc = test()
|
| 246 |
|
| 247 |
if os.path.exists(METRIC_PATH):
|
| 248 |
metric_dict = read_json(METRIC_PATH)
|
| 249 |
+
metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc]
|
| 250 |
else:
|
| 251 |
metric_dict={}
|
| 252 |
metric_dict['all'] = [test_acc]
|