diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 69c764e06e48f..883eed320f80d 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -99,7 +99,7 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
-            schedule = {1: accumulate_grad_batches}
+            schedule = {0: accumulate_grad_batches}
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index b9f955ed22331..ed0421a077f2f 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -154,7 +154,7 @@ def test_accumulation_and_early_stopping(tmpdir):
         'Learning rate was not altered after running learning rate finder'
     assert len(lrfinder.results['lr']) == 100, \
         'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == 190, \
+    assert lrfinder._total_batch_idx == 100 * 2, \
        'Accumulation parameter did not work'


diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 68b41d65471b0..a078c40d48a93 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -103,7 +103,15 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     model_2.eval()


-def test_gradient_accumulation_scheduling(tmpdir):
+@pytest.mark.parametrize(
+    ['schedule', 'expected'],
+    [
+        pytest.param({1: 2, 3: 4}, [1, 2, 4]),
+        pytest.param(3, [3, 3, 3]),
+        pytest.param(4, [4, 4, 4])
+    ]
+)
+def test_gradient_accumulation_scheduling(tmpdir, schedule, expected):
     """
     Test grad accumulation by the freq of optimizer updates
     """
@@ -123,59 +131,59 @@ def test_gradient_accumulation_scheduling(tmpdir):
     with pytest.raises(TypeError):
         assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

+    model = EvalModelTemplate()
+
+    trainer = Trainer(accumulate_grad_batches=schedule,
+                      limit_train_batches=0.8,
+                      limit_val_batches=0.8,
+                      max_epochs=4,
+                      default_root_dir=tmpdir)
+
     # test optimizer call freq matches scheduler
-    def _optimizer_step(self, epoch, batch_idx, optimizer,
-                        optimizer_idx, second_order_closure=None):
+    def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
+                        second_order_closure=None, on_tpu=False,
+                        using_native_amp=False, using_lbfgs=False):
         # only test the first 12 batches in epoch
         if batch_idx < 12:
             if epoch == 0:
                 # reset counter when starting epoch
-                if batch_idx == 0:
-                    self.prev_called_batch_idx = 0
+                if batch_idx == expected[0] - 1:
+                    model.prev_called_batch_idx = expected[0] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 1
+                    assert trainer.accumulate_grad_batches == expected[0]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 1
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[0]

             elif 1 <= epoch <= 2:
                 # reset counter when starting epoch
-                if batch_idx == 1:
-                    self.prev_called_batch_idx = 1
+                if batch_idx == expected[1] - 1:
+                    model.prev_called_batch_idx = expected[1] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 2
+                    assert trainer.accumulate_grad_batches == expected[1]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 2
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[1]

             else:
-                if batch_idx == 3:
-                    self.prev_called_batch_idx = 3
+                if batch_idx == expected[2] - 1:
+                    model.prev_called_batch_idx = expected[2] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 4
+                    assert trainer.accumulate_grad_batches == expected[2]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 3
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[2]

             optimizer.step()

         # clear gradients
         optimizer.zero_grad()

-    model = EvalModelTemplate()
-    schedule = {1: 2, 3: 4}
-
-    trainer = Trainer(accumulate_grad_batches=schedule,
-                      limit_train_batches=0.1,
-                      limit_val_batches=0.1,
-                      max_epochs=2,
-                      default_root_dir=tmpdir)
-
-    # for the test
-    trainer.optimizer_step = _optimizer_step
+    model.optimizer_step = _optimizer_step
     model.prev_called_batch_idx = 0

     trainer.fit(model)
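
Note on the patched behavior (not part of the diff): with an int `accumulate_grad_batches`, the schedule is now keyed at epoch 0, so the factor applies from the first epoch onward; a dict such as `{1: 2, 3: 4}` keeps the default factor of 1 until epoch 1 and switches to 4 at epoch 3, which is what the parametrized test expects as `[1, 2, 4]`. Below is a minimal sketch of that lookup rule; the helper name is hypothetical and this is not the GradientAccumulationScheduler implementation.

# Illustrative sketch only: resolve the accumulation factor in effect at a given epoch,
# taking the value of the largest schedule key that is <= epoch (default 1 before any key).
def resolve_accumulation_factor(schedule: dict, epoch: int, default: int = 1) -> int:
    factor = default
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor

# Matches the expectations in the parametrized test:
assert [resolve_accumulation_factor({1: 2, 3: 4}, e) for e in (0, 1, 3)] == [1, 2, 4]
assert [resolve_accumulation_factor({0: 4}, e) for e in (0, 1, 3)] == [4, 4, 4]  # int case after the fix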