Commit e2ead9a

Refactor some loops code and hook tests (#7682)
1 parent: 8ba6304

6 files changed (+133, -219 lines)


pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 9 additions & 12 deletions
@@ -11,30 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
+from weakref import proxy

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class OptimizerConnector:

-    def __init__(self, trainer):
-        self.trainer = trainer
+    def __init__(self, trainer: 'pl.Trainer') -> None:
+        self.trainer = proxy(trainer)

-    def on_trainer_init(self):
+    def on_trainer_init(self) -> None:
         self.trainer.lr_schedulers = []
         self.trainer.optimizers = []
         self.trainer.optimizer_frequencies = []

-    def update_learning_rates(
-        self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None
-    ):
+    def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
         """Update learning rates.

         Args:
             interval: either 'epoch' or 'step'.
-            monitor_metrics: dict of possible values to monitor
+            opt_indices: indices of the optimizers to update.
         """
         if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
             return
@@ -55,10 +55,7 @@ def update_learning_rates(
                 monitor_key, monitor_val = None, None
                 if lr_scheduler['reduce_on_plateau']:
                     monitor_key = lr_scheduler['monitor']
-                    monitor_val = (
-                        monitor_metrics.get(monitor_key) if monitor_metrics is not None else
-                        self.trainer.logger_connector.callback_metrics.get(monitor_key)
-                    )
+                    monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
                     if monitor_val is None:
                         if lr_scheduler.get('strict', True):
                             avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
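The connector now stores a weakref proxy to the trainer instead of a strong reference, which breaks the trainer/connector reference cycle. A minimal standalone sketch of the pattern (illustrative names, not Lightning code):

import weakref


class Connector:
    def __init__(self, owner):
        # A weakref.proxy behaves like the owner but does not keep it alive,
        # so the owner <-> connector cycle is broken and plain reference
        # counting can free the owner.
        self.owner = weakref.proxy(owner)


class Owner:
    def __init__(self):
        self.connector = Connector(self)


owner = Owner()
watcher = weakref.ref(owner)
del owner
assert watcher() is None  # freed immediately, no cycle collection needed

The trade-off is that touching the proxy after the referent is gone raises a ReferenceError instead of silently keeping a dead trainer alive.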

pytorch_lightning/trainer/training_loop.py

Lines changed: 23 additions & 23 deletions
@@ -14,7 +14,7 @@

 from collections import OrderedDict
 from contextlib import contextmanager, suppress
-from copy import copy, deepcopy
+from copy import copy
 from functools import partial, update_wrapper
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -478,7 +478,6 @@ def run_training_epoch(self):

         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-
         batch_idx = None
         is_last_batch = None

@@ -525,8 +524,7 @@ def run_training_epoch(self):
             self.save_loggers_on_train_batch_end()

             # update LR schedulers
-            monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
-            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
+            self.update_lr_schedulers('step')
             self.trainer.checkpoint_connector.has_trained = True

             self.total_batch_idx += 1
@@ -567,7 +565,7 @@ def run_training_epoch(self):

         # update epoch level lr_schedulers if no val loop outside train loop is triggered
         if not should_check_val or should_train_only:
-            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.update_lr_schedulers('epoch')

         if should_train_only:
             self.check_checkpoint_callback(True)
@@ -863,17 +861,16 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
         # track gradients
         result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)

-    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
-        num_accumulated_batches_reached = self._accumulated_batches_reached()
-        num_training_batches_reached = self._num_training_batches_reached()
-
-        if num_accumulated_batches_reached or num_training_batches_reached:
-            # update lr
-            self.trainer.optimizer_connector.update_learning_rates(
-                interval="step",
-                monitor_metrics=monitor_metrics,
-                opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
-            )
+    def update_lr_schedulers(self, interval: str) -> None:
+        if interval == "step":
+            finished_accumulation = self._accumulated_batches_reached()
+            finished_epoch = self._num_training_batches_reached()
+            if not finished_accumulation and not finished_epoch:
+                return
+        self.trainer.optimizer_connector.update_learning_rates(
+            interval=interval,
+            opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
+        )

     def increment_accumulated_grad_global_step(self):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
@@ -897,15 +894,21 @@ def should_accumulate(self):

     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
         """ Decide if we should run validation. """
-
         if not self.trainer.enable_validation:
             return False

-        # check if this epoch is eligible to run validation
-        if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        if not is_val_check_epoch:
             return False

         # val_check_batch is inf for iterable datasets with no length defined
+        is_infinite_dataset = self.trainer.val_check_batch == float('inf')
+        if on_epoch and is_last_batch and is_infinite_dataset:
+            return True
+
+        if on_epoch and self.trainer.should_stop:
+            return True
+
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = False
         if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
@@ -915,12 +918,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo

         # Note: num_training_batches is also inf for iterable datasets with no length defined
         epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")

         if on_epoch:
-            return (
-                is_val_check_batch and epoch_end_val_check
-            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+            return is_val_check_batch and epoch_end_val_check
         else:
             return is_val_check_batch and not epoch_end_val_check

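The step-interval branch of the new update_lr_schedulers only reaches the optimizer connector once an accumulation window has finished or the epoch has run out of batches. A toy simulation of that guard (plain Python, not Lightning code; the modular arithmetic is an assumption about how the _accumulated_batches_reached / _num_training_batches_reached helpers behave):

# Hypothetical stand-ins for the loop's private helpers, assuming they compare
# the 1-based batch count against the accumulation window and the epoch length.
accumulate_grad_batches = 2
num_training_batches = 5


def accumulated_batches_reached(batch_idx: int) -> bool:
    return (batch_idx + 1) % accumulate_grad_batches == 0


def num_training_batches_reached(batch_idx: int) -> bool:
    return batch_idx + 1 == num_training_batches


# Batches after which the 'step'-interval scheduler update would fire.
fired = [
    i for i in range(num_training_batches)
    if accumulated_batches_reached(i) or num_training_batches_reached(i)
]
print(fired)  # [1, 3, 4]: after each full window of 2, plus the leftover last batch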

tests/callbacks/test_early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def test_early_stopping_patience_train(
     """Test to ensure that early stopping is not triggered before patience is exhausted."""

     class ModelOverrideTrainReturn(BoringModel):
-        train_return_values = torch.Tensor(loss_values)
+        train_return_values = torch.tensor(loss_values)

         def training_epoch_end(self, outputs):
             loss = self.train_return_values[self.current_epoch]
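The swap from torch.Tensor(loss_values) to torch.tensor(loss_values) uses the idiomatic constructor: torch.Tensor is the float32 tensor class and always yields float32, while torch.tensor infers the dtype from the data. A quick illustration:

import torch

print(torch.Tensor([1, 2, 3]).dtype)   # torch.float32 (class constructor, always float32)
print(torch.tensor([1, 2, 3]).dtype)   # torch.int64   (dtype inferred from the ints)
print(torch.tensor([0.1, 0.2]).dtype)  # torch.float32 (dtype inferred from the floats)

For the float loss_values used in this test both forms produce the same values, so the change is a cleanup rather than a behavior change.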

tests/loggers/test_tensorboard.py

Lines changed: 12 additions & 37 deletions
@@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):


 @mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics')
-@pytest.mark.parametrize('expected', [
-    ([5, 11, 17]),
-])
-def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
-    """
-    Tests to ensure that tensorboard log properly when accumulated_gradients > 1
-    """
+def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
+    """Tests to ensure that tensorboard log properly when accumulated_gradients > 1"""

     class TestModel(BoringModel):

         def __init__(self):
             super().__init__()
-            self._count = 0
-            self._indexes = []
-
-        def training_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            self.log('count', self._count, on_step=True, on_epoch=True)
-            self.log('loss', loss, on_step=True, on_epoch=True)
+            self.indexes = []

+        def training_step(self, *args):
+            self.log('foo', 1, on_step=True, on_epoch=True)
             if not self.trainer.train_loop.should_accumulate():
                 if self.trainer.logger_connector.should_update_logs:
-                    self._indexes.append(self.trainer.global_step)
-
-            return loss
-
-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            self.log('val_loss', loss, on_step=True, on_epoch=True)
-            return loss
-
-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
-            return [optimizer], [lr_scheduler]
+                    self.indexes.append(self.trainer.global_step)
+            return super().training_step(*args)

     model = TestModel()
     model.training_epoch_end = None
-    model.validation_epoch_end = None
-
     logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False)
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=12,
         limit_val_batches=0,
         max_epochs=3,
-        gpus=0,
         accumulate_grad_batches=2,
         logger=[logger_0],
         log_every_n_steps=3,
     )
     trainer.fit(model)

-    mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]]
-    assert mock_count_epochs == expected
+    calls = [m[2] for m in mock_log_metrics.mock_calls]
+    count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]]
+    assert count_epochs == [5, 11, 17]

-    mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
-    assert model._indexes == mock_count_steps
+    count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]]
+    assert count_steps == model.indexes


 @mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter')
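The hard-coded expectation [5, 11, 17] follows from the Trainer arguments above: 12 training batches per epoch with accumulate_grad_batches=2 gives 6 optimizer steps per epoch, and global_step is zero-based, so the epoch-level foo_epoch metric is logged at the last step of each of the 3 epochs. A back-of-the-envelope check (plain Python, not part of the test):

limit_train_batches = 12
accumulate_grad_batches = 2
max_epochs = 3

steps_per_epoch = limit_train_batches // accumulate_grad_batches  # 6 optimizer steps per epoch
expected = [steps_per_epoch * (epoch + 1) - 1 for epoch in range(max_epochs)]
print(expected)  # [5, 11, 17]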
