diff --git a/mnist/main.py b/mnist/main.py index 9e9db4374d..4abbac2f6d 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -49,7 +49,7 @@ def train(args, model, device, train_loader, optimizer, epoch): 100. * batch_idx / len(train_loader), loss.item())) -def test(args, model, device, test_loader): +def test(model, device, test_loader): model.eval() test_loss = 0 correct = 0 @@ -118,7 +118,7 @@ def main(): scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) for epoch in range(1, args.epochs + 1): train(args, model, device, train_loader, optimizer, epoch) - test(args, model, device, test_loader) + test(model, device, test_loader) scheduler.step() if args.save_model: