21 changes: 9 additions & 12 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -11,30 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import List, Optional
from weakref import proxy

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class OptimizerConnector:

def __init__(self, trainer):
self.trainer = trainer
def __init__(self, trainer: 'pl.Trainer') -> None:
self.trainer = proxy(trainer)

def on_trainer_init(self):
def on_trainer_init(self) -> None:
self.trainer.lr_schedulers = []
self.trainer.optimizers = []
self.trainer.optimizer_frequencies = []

def update_learning_rates(
self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None
):
def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
"""Update learning rates.

Args:
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
opt_indices: indices of the optimizers to update.
"""
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return
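
For context on the proxy(trainer) assignment above: holding the trainer through weakref.proxy keeps the trainer/connector pair from forming a strong reference cycle. A minimal, self-contained sketch of the pattern (the class names here are illustrative, not Lightning's):

import weakref


class Connector:
    """Holds its owner through a weak proxy so owner <-> connector is not a strong cycle."""

    def __init__(self, owner):
        self.owner = weakref.proxy(owner)


class Owner:
    def __init__(self):
        self.name = "owner"
        # Owner -> Connector is a strong reference, Connector -> Owner is weak,
        # so dropping the Owner frees both objects without the cycle collector.
        self.connector = Connector(self)


owner = Owner()
print(owner.connector.owner.name)  # the proxy forwards attribute access: "owner"
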
@@ -55,10 +55,7 @@ def update_learning_rates(
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key) if monitor_metrics is not None else
self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
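Dropping the monitor_metrics argument means the monitored value for ReduceLROnPlateau-style schedulers is now read from trainer.logger_connector.callback_metrics, i.e. from whatever the LightningModule logs via self.log. A hedged sketch of the user-facing contract this relies on (module and metric names are illustrative):

import torch
import pytorch_lightning as pl


class PlateauModule(pl.LightningModule):
    """Illustrative module: the scheduler's 'monitor' key must name a logged metric."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).mean()

    def validation_step(self, batch, batch_idx):
        val_loss = self.layer(batch).mean()
        # self.log() is what puts 'val_loss' into callback_metrics, the dict
        # update_learning_rates() now reads the monitored value from.
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }
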
46 changes: 23 additions & 23 deletions pytorch_lightning/trainer/training_loop.py
@@ -14,7 +14,7 @@

from collections import OrderedDict
from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -478,7 +478,6 @@ def run_training_epoch(self):

train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0

batch_idx = None
is_last_batch = None

@@ -525,8 +524,7 @@ def run_training_epoch(self):
self.save_loggers_on_train_batch_end()

# update LR schedulers
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.update_lr_schedulers('step')
self.trainer.checkpoint_connector.has_trained = True

# max steps reached, end training
@@ -567,7 +565,7 @@ def run_training_epoch(self):

# update epoch level lr_schedulers if no val loop outside train loop is triggered
if not should_check_val or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
self.update_lr_schedulers('epoch')

if should_train_only:
self.check_checkpoint_callback(True)
@@ -863,17 +861,16 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
# track gradients
result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)

def update_train_loop_lr_schedulers(self, monitor_metrics=None):
num_accumulated_batches_reached = self._accumulated_batches_reached()
num_training_batches_reached = self._num_training_batches_reached()

if num_accumulated_batches_reached or num_training_batches_reached:
# update lr
self.trainer.optimizer_connector.update_learning_rates(
interval="step",
monitor_metrics=monitor_metrics,
opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
)
def update_lr_schedulers(self, interval: str) -> None:
if interval == "step":
finished_accumulation = self._accumulated_batches_reached()
finished_epoch = self._num_training_batches_reached()
if not finished_accumulation and not finished_epoch:
return
self.trainer.optimizer_connector.update_learning_rates(
interval=interval,
opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
)

def increment_accumulated_grad_global_step(self):
num_accumulated_batches_reached = self._accumulated_batches_reached()
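
The consolidated update_lr_schedulers('step') only acts once an accumulation window or the epoch completes, so a step-interval scheduler advances once per optimizer step rather than once per batch. A sketch of the scheduler configuration that takes this path (the function name and values are illustrative):

import torch


def configure_optimizers_sketch(params):
    """Hypothetical configure_optimizers body for a per-step scheduler."""
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
    # interval='step' routes the scheduler through update_lr_schedulers('step');
    # with accumulate_grad_batches=2 it then advances every second batch.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }


# e.g. configure_optimizers_sketch(torch.nn.Linear(4, 1).parameters())
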
@@ -897,15 +894,21 @@ def should_accumulate(self):

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
""" Decide if we should run validation. """

if not self.trainer.enable_validation:
return False

# check if this epoch is eligible to run validation
if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
if not is_val_check_epoch:
return False

# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
if on_epoch and is_last_batch and is_infinite_dataset:
return True

if on_epoch and self.trainer.should_stop:
return True

# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = False
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
@@ -915,12 +918,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo

# Note: num_training_batches is also inf for iterable datasets with no length defined
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")

if on_epoch:
return (
is_val_check_batch and epoch_end_val_check
) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
return is_val_check_batch and epoch_end_val_check
else:
return is_val_check_batch and not epoch_end_val_check

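The _should_check_val_fx rewrite keeps the same inputs but adds early returns for length-less iterable datasets (where val_check_batch is float('inf')) and for should_stop at epoch end. The checks are driven by ordinary Trainer arguments; an illustrative configuration (values are arbitrary):

import pytorch_lightning as pl

# Illustrative values only; these public flags feed the checks above.
trainer = pl.Trainer(
    check_val_every_n_epoch=2,  # drives the is_val_check_epoch test
    val_check_interval=0.25,    # translated internally into trainer.val_check_batch
    limit_train_batches=100,    # interacts with val_check_batch (see the TODO above)
)
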
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -157,7 +157,7 @@ def test_early_stopping_patience_train(
"""Test to ensure that early stopping is not triggered before patience is exhausted."""

class ModelOverrideTrainReturn(BoringModel):
train_return_values = torch.Tensor(loss_values)
train_return_values = torch.tensor(loss_values)

def training_epoch_end(self, outputs):
loss = self.train_return_values[self.current_epoch]
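The one-character change from torch.Tensor to torch.tensor is deliberate: torch.tensor is the factory that infers dtype from its data and never interprets the argument as a size. The test's loss_values are floats, so the resulting values are the same either way; torch.tensor is simply the safer, recommended spelling. A quick illustration of the difference (not part of the PR):

import torch

torch.Tensor(3)           # size argument: an uninitialized float32 tensor of 3 elements
torch.tensor(3)           # data argument: a 0-d int64 tensor holding the value 3
torch.Tensor([1, 2])      # values kept but silently cast to float32
torch.tensor([1, 2])      # dtype inferred from the data: int64
torch.tensor([0.9, 0.8])  # float data stays float32, matching what loss_values holds
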
49 changes: 12 additions & 37 deletions tests/loggers/test_tensorboard.py
@@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):


@mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics')
@pytest.mark.parametrize('expected', [
([5, 11, 17]),
])
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
"""
Tests to ensure that tensorboard log properly when accumulated_gradients > 1
"""
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
"""Tests to ensure that tensorboard log properly when accumulated_gradients > 1"""

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self._count = 0
self._indexes = []

def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('count', self._count, on_step=True, on_epoch=True)
self.log('loss', loss, on_step=True, on_epoch=True)
self.indexes = []

def training_step(self, *args):
self.log('foo', 1, on_step=True, on_epoch=True)
if not self.trainer.train_loop.should_accumulate():
if self.trainer.logger_connector.should_update_logs:
self._indexes.append(self.trainer.global_step)

return loss

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('val_loss', loss, on_step=True, on_epoch=True)
return loss

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
self.indexes.append(self.trainer.global_step)
return super().training_step(*args)

model = TestModel()
model.training_epoch_end = None
model.validation_epoch_end = None

logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False)

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=12,
limit_val_batches=0,
max_epochs=3,
gpus=0,
accumulate_grad_batches=2,
logger=[logger_0],
log_every_n_steps=3,
)
trainer.fit(model)

mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]]
assert mock_count_epochs == expected
calls = [m[2] for m in mock_log_metrics.mock_calls]
count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]]
assert count_epochs == [5, 11, 17]

mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
assert model._indexes == mock_count_steps
count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]]
assert count_steps == model.indexes


@mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter')
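On the expected epoch-end steps in the rewritten test: with limit_train_batches=12, accumulate_grad_batches=2 and max_epochs=3 there are 6 optimizer steps per epoch, and the epoch-end metrics appear to be written at global_step - 1, which yields 5, 11 and 17. A small check of that arithmetic (the "- 1" offset is an assumption about the logging convention, not something this diff states):

# Hedged arithmetic behind the expected epoch-end steps [5, 11, 17].
batches_per_epoch, accumulation, max_epochs = 12, 2, 3
steps_per_epoch = batches_per_epoch // accumulation          # 6 optimizer steps per epoch
expected = [steps_per_epoch * (epoch + 1) - 1 for epoch in range(max_epochs)]
assert expected == [5, 11, 17]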