2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_tricks.py
@@ -99,7 +99,7 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
-            schedule = {1: accumulate_grad_batches}
+            schedule = {0: accumulate_grad_batches}
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
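The hunk above changes the schedule built from an int so that gradient accumulation kicks in at epoch 0 instead of epoch 1. A minimal sketch of the resulting behaviour, assuming only that the scheduler applies the factor of the most recent scheduled epoch; the resolve_accumulation helper is hypothetical, not Lightning code:

# Hypothetical helper mirroring how an {epoch: factor} schedule resolves to a
# per-epoch accumulation factor; illustrative only, not GradientAccumulationScheduler itself.
def resolve_accumulation(schedule, epoch, default=1):
    applicable = [e for e in schedule if e <= epoch]
    return schedule[max(applicable)] if applicable else default

# Before the fix, accumulate_grad_batches=4 became {1: 4}: epoch 0 ran without accumulation.
assert resolve_accumulation({1: 4}, epoch=0) == 1
# After the fix it becomes {0: 4}, so accumulation applies from the first epoch.
assert resolve_accumulation({0: 4}, epoch=0) == 4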
2 changes: 1 addition & 1 deletion tests/trainer/test_lr_finder.py
@@ -154,7 +154,7 @@ def test_accumulation_and_early_stopping(tmpdir):
         'Learning rate was not altered after running learning rate finder'
     assert len(lrfinder.results['lr']) == 100, \
         'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == 190, \
+    assert lrfinder._total_batch_idx == 100 * 2, \
         'Accumulation parameter did not work'


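The updated assertion expresses the expected total batch count as lr-finder steps times the accumulation factor rather than a magic number. A quick sanity check, assuming the test runs 100 lr-finder steps with accumulate_grad_batches=2 (values inferred from the surrounding asserts, not stated in this hunk):

# Rough arithmetic behind the new assert (configuration values are assumptions
# inferred from the test: 100 lr-finder steps, accumulation factor 2).
num_training = 100             # optimizer steps taken by the lr finder
accumulate_grad_batches = 2    # batches consumed per optimizer step
assert num_training * accumulate_grad_batches == 200   # i.e. the `100 * 2` in the test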
64 changes: 36 additions & 28 deletions tests/trainer/test_trainer.py
@@ -103,7 +103,15 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     model_2.eval()


-def test_gradient_accumulation_scheduling(tmpdir):
+@pytest.mark.parametrize(
+    ['schedule', 'expected'],
+    [
+        pytest.param({1: 2, 3: 4}, [1, 2, 4]),
+        pytest.param(3, [3, 3, 3]),
+        pytest.param(4, [4, 4, 4])
+    ]
+)
+def test_gradient_accumulation_scheduling(tmpdir, schedule, expected):
     """
     Test grad accumulation by the freq of optimizer updates
     """
@@ -123,59 +131,59 @@ def test_gradient_accumulation_scheduling(tmpdir):
     with pytest.raises(TypeError):
         assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

+    model = EvalModelTemplate()
+
+    trainer = Trainer(accumulate_grad_batches=schedule,
+                      limit_train_batches=0.8,
+                      limit_val_batches=0.8,
+                      max_epochs=4,
+                      default_root_dir=tmpdir)

     # test optimizer call freq matches scheduler
-    def _optimizer_step(self, epoch, batch_idx, optimizer,
-                        optimizer_idx, second_order_closure=None):
+    def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
+                        second_order_closure=None, on_tpu=False,
+                        using_native_amp=False, using_lbfgs=False):
         # only test the first 12 batches in epoch
         if batch_idx < 12:
             if epoch == 0:
                 # reset counter when starting epoch
-                if batch_idx == 0:
-                    self.prev_called_batch_idx = 0
+                if batch_idx == expected[0] - 1:
+                    model.prev_called_batch_idx = expected[0] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 1
+                    assert trainer.accumulate_grad_batches == expected[0]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 1
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[0]

             elif 1 <= epoch <= 2:
                 # reset counter when starting epoch
-                if batch_idx == 1:
-                    self.prev_called_batch_idx = 1
+                if batch_idx == expected[1] - 1:
+                    model.prev_called_batch_idx = expected[1] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 2
+                    assert trainer.accumulate_grad_batches == expected[1]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 2
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[1]

             else:
-                if batch_idx == 3:
-                    self.prev_called_batch_idx = 3
+                if batch_idx == expected[2] - 1:
+                    model.prev_called_batch_idx = expected[2] - 1

                     # use this opportunity to test once
-                    assert self.trainer.accumulate_grad_batches == 4
+                    assert trainer.accumulate_grad_batches == expected[2]

-                assert batch_idx == self.prev_called_batch_idx
-                self.prev_called_batch_idx += 3
+                assert batch_idx == model.prev_called_batch_idx
+                model.prev_called_batch_idx += expected[2]

             optimizer.step()

         # clear gradients
         optimizer.zero_grad()

-    model = EvalModelTemplate()
-    schedule = {1: 2, 3: 4}
-
-    trainer = Trainer(accumulate_grad_batches=schedule,
-                      limit_train_batches=0.1,
-                      limit_val_batches=0.1,
-                      max_epochs=2,
-                      default_root_dir=tmpdir)
-
     # for the test
-    trainer.optimizer_step = _optimizer_step
+    model.optimizer_step = _optimizer_step
+    model.prev_called_batch_idx = 0

     trainer.fit(model)
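Taken together, the parametrization encodes the per-epoch accumulation factors the overridden optimizer_step should observe over the four training epochs. A small sketch of that mapping under the fixed int-to-dict conversion; the factors_per_epoch helper is illustrative and not part of the test:

# Illustrative helper: expand a schedule into the accumulation factor per epoch,
# applying the int -> {0: int} conversion introduced in training_tricks.py.
def factors_per_epoch(schedule, num_epochs):
    if isinstance(schedule, int):
        schedule = {0: schedule}
    factor, factors = 1, []
    for epoch in range(num_epochs):
        factor = schedule.get(epoch, factor)
        factors.append(factor)
    return factors

# expected=[1, 2, 4]: epoch 0 -> 1, epochs 1-2 -> 2, epoch 3 -> 4
assert factors_per_epoch({1: 2, 3: 4}, num_epochs=4) == [1, 2, 2, 4]
# int schedules now apply from epoch 0, so every epoch uses the same factor
assert factors_per_epoch(3, num_epochs=4) == [3, 3, 3, 3]
assert factors_per_epoch(4, num_epochs=4) == [4, 4, 4, 4]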