10 changes: 3 additions & 7 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -24,12 +24,11 @@ def on_trainer_init(self):
self.trainer.optimizers = []
self.trainer.optimizer_frequencies = []

def update_learning_rates(self, interval: str, monitor_metrics=None):
def update_learning_rates(self, interval: str):
"""Update learning rates.

Args:
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
"""
if not self.trainer.lr_schedulers:
return
@@ -44,11 +44,8 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key)
if monitor_metrics is not None
else self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)

if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = self.trainer.logger_connector.callback_metrics.keys()
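With the `monitor_metrics` argument gone, `update_learning_rates` now looks the monitored value up directly in `callback_metrics` for `ReduceLROnPlateau` schedulers. A minimal sketch of the user-facing side, assuming the standard scheduler-dict keys accepted by `configure_optimizers`; the model and metric names are illustrative only:

```python
# Minimal sketch (assumes the standard Lightning scheduler-dict format);
# the monitored metric must be logged so it ends up in callback_metrics.
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)  # lands in trainer.logger_connector.callback_metrics
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "train_loss",  # key looked up by update_learning_rates
                "interval": "epoch",
            },
        }
```

Because the metric is logged in `training_step`, it is already present in `callback_metrics` by the time the epoch-interval update runs.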
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -526,9 +526,6 @@ def train(self):
self.train_loop.on_train_end()
return

# update LR schedulers
self.optimizer_connector.update_learning_rates(interval='epoch')

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
52 changes: 36 additions & 16 deletions pytorch_lightning/trainer/training_loop.py
@@ -538,6 +538,10 @@ def run_training_epoch(self):
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

# update LR schedulers
self._update_train_loop_lr_schedulers_step()
self._update_train_loop_lr_schedulers_epoch(is_last_batch, batch_output.signal)

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
break
@@ -573,17 +577,15 @@ def run_training_epoch(self):
# -----------------------------------------
self.save_loggers_on_train_batch_end()

# update LR schedulers
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.trainer.checkpoint_connector.has_trained = True

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
accumulation_done = self._accumulated_batches_reached()
# Ensure accumulation across batches has completed before breaking loop
if accumulation_done:
break
# max steps reached and accumulation across batches has completed, end training
if (
self.trainer.max_steps is not None
and self.trainer.max_steps == self.trainer.global_step + 1
and self._accumulated_batches_reached()
):
break

# end epoch early
# stop when the flag is changed or we've gone past the amount
@@ -594,7 +596,7 @@
self.trainer.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if (batch_idx + 1) >= self.trainer.num_training_batches:
if self._num_training_batches_reached():
break

# progress global step according to grads progress
@@ -611,7 +613,7 @@
self.num_optimizers
)

# when no val loop is present or fast-dev-run still need to call checkpoints
# when no val loop is present or fast_dev_run still need to call checkpoints
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))

# increment the global step once
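The `max_steps` exit is now a single condition that only breaks the loop once the current accumulation window has completed, so `global_step` stays aligned with finished optimizer steps. A hypothetical configuration exercising this path, reusing the `LitModel` sketch above (the random-tensor dataloader is illustrative only):

```python
# Hypothetical run: with accumulate_grad_batches=4, global_step advances every
# fourth batch, so max_steps is honoured only at an accumulation boundary.
from torch.utils.data import DataLoader

train_loader = DataLoader(torch.randn(256, 32), batch_size=8)  # 32 batches per epoch
trainer = pl.Trainer(max_steps=10, accumulate_grad_batches=4, max_epochs=5)
trainer.fit(LitModel(), train_loader)
```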
@@ -805,13 +807,26 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
# track gradients
self.track_and_norm_grad(optimizer=optimizer)

def update_train_loop_lr_schedulers(self, monitor_metrics=None):
def _update_train_loop_lr_schedulers_step(self):
num_accumulated_batches_reached = self._accumulated_batches_reached()
num_training_batches_reached = self._num_training_batches_reached()

if num_accumulated_batches_reached or num_training_batches_reached:
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
self.trainer.optimizer_connector.update_learning_rates(interval="step")

def _update_train_loop_lr_schedulers_epoch(self, is_last_batch, batch_signal):
num_training_batches_reached = self._num_training_batches_reached()
is_last_batch_for_inf_ds = self._is_last_batch_for_infinite_dataset(is_last_batch)

if (
num_training_batches_reached
or is_last_batch_for_inf_ds
or self.trainer.should_stop
or batch_signal == -1
):
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

def run_on_epoch_end_hook(self, epoch_output):
self.trainer.call_hook('on_epoch_end')
@@ -831,7 +846,10 @@ def _accumulated_batches_reached(self):
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

def _num_training_batches_reached(self):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
return (self.trainer.batch_idx + 1) >= self.trainer.num_training_batches

def _is_last_batch_for_infinite_dataset(self, is_last_batch):
return is_last_batch and self.trainer.val_check_batch == float("inf")

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
@@ -845,8 +863,10 @@ def should_check_val_fx(self, batch_idx, is_last_batch):
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
should_check_val = is_val_check_batch or self.trainer.should_stop
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
should_check_val = (
can_check_val
and (should_check_val or self._is_last_batch_for_infinite_dataset(is_last_batch))
)

return should_check_val

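Taken together, epoch-interval schedulers are now stepped from `_update_train_loop_lr_schedulers_epoch` once the epoch's training batches are exhausted (or the last batch of an infinite dataset is seen, `should_stop` is set, or the batch returns -1), while step-interval schedulers are stepped from `_update_train_loop_lr_schedulers_step` after every completed accumulation window. A sketch of a step-interval scheduler, written as an alternative `configure_optimizers` for the `LitModel` above; the decay factor is illustrative only:

```python
# Sketch: a scheduler declared with interval="step" is advanced by
# _update_train_loop_lr_schedulers_step() once per completed optimizer step.
def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.95 ** step)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```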