8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -46,6 +46,14 @@ def on_sanity_check_end(self, trainer, pl_module):
"""Called when the validation sanity check ends."""
pass

def on_train_batch_start(self, trainer, pl_module):
"""Called when the validation batch begins."""
pass

def on_train_batch_end(self, trainer, pl_module):
"""Called when the validation batch ends."""
pass

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the train epoch begins."""
pass
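For reference, a minimal sketch of how a user-defined callback might adopt the two new hooks (the `PrintingCallback` name and the printed messages are purely illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):
    """Illustrative callback: reports progress around every training batch."""

    def on_train_batch_start(self, trainer, pl_module):
        # invoked by the trainer just before a training batch is processed
        print(f'starting batch {trainer.batch_idx} of epoch {trainer.current_epoch}')

    def on_train_batch_end(self, trainer, pl_module):
        # invoked by the trainer right after the training batch is processed
        print(f'finished batch {trainer.batch_idx}')


# the callback is registered on the Trainer as usual
trainer = Trainer(callbacks=[PrintingCallback()])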
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_logger.py
@@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
# Initialize for storing values
self.lrs = {name: [] for name in names}

def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/progress.py
@@ -36,8 +36,8 @@ def __init__(self):
def disable(self):
self.enable = False

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module) # don't forget this :)
def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module) # don't forget this :)
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')
@@ -138,7 +138,7 @@ def on_train_start(self, trainer, pl_module):
def on_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module):
self._train_batch_idx += 1

def on_validation_start(self, trainer, pl_module):
@@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module)
def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module)
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
self.main_progress_bar.update(self.refresh_rate)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
21 changes: 21 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -77,6 +77,23 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_train_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Args:
batch: The batched data as it is returned by the training DataLoader.
"""
# do something when the batch starts

def on_train_batch_end(self) -> None:
"""
Called in the training loop after the batch.
"""
# do something when the batch ends

def on_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.
@@ -85,12 +102,16 @@ def on_batch_start(self, batch: Any) -> None:

Args:
batch: The batched data as it is returned by the training DataLoader.

.. warning:: Deprecated in 0.9.0, will be removed in 1.0.0 (use `on_train_batch_start` instead)
"""
# do something when the batch starts

def on_batch_end(self) -> None:
"""
Called in the training loop after the batch.

.. warning:: Deprecated in 0.9.0, will be removed in 1.0.0 (use `on_train_batch_end` instead)
"""
# do something when the batch ends

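A minimal sketch of the module-level hooks above (the `LitModel` class and its skip condition are hypothetical): returning -1 from `on_train_batch_start` skips training for the rest of the current epoch, and the deprecated `on_batch_start`/`on_batch_end` pair maps one-to-one onto the new names:

import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def on_train_batch_start(self, batch):
        # hypothetical guard: abort the rest of the epoch on a sentinel batch
        if batch is None:
            return -1

    def on_train_batch_end(self):
        # runs after every training batch; replaces the deprecated on_batch_end
        pass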
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
raise ValueError(f'input_sample and example_input_array tensors are both missing.')
raise ValueError('input_sample and example_input_array tensors are both missing.')

if 'example_outputs' not in kwargs:
self.eval()
12 changes: 11 additions & 1 deletion pytorch_lightning/trainer/callback_hook.py
@@ -9,7 +9,7 @@ class TrainerCallbackHookMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
callbacks: List[Callback] = []
get_model: Callable = ...
get_model: Callable

def setup(self, stage: str):
"""Called in the beginning of fit and test"""
@@ -111,6 +111,16 @@ def on_batch_end(self):
for callback in self.callbacks:
callback.on_batch_end(self, self.get_model())

def on_train_batch_start(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_train_batch_start(self, self.get_model())

def on_train_batch_end(self):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_train_batch_end(self, self.get_model())

def on_validation_batch_start(self):
"""Called when the validation batch begins."""
for callback in self.callbacks:
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/lr_finder.py
@@ -382,7 +382,7 @@ def on_batch_start(self, trainer, pl_module):

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -263,6 +263,8 @@ class TrainerTrainLoopMixin(ABC):
on_train_end: Callable
on_batch_start: Callable
on_batch_end: Callable
on_train_batch_start: Callable
on_train_batch_end: Callable
on_epoch_start: Callable
on_epoch_end: Callable
on_validation_end: Callable
@@ -690,6 +692,7 @@ def run_training_batch(self, batch, batch_idx):
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

# Batch start events
# TODO: deprecate 1.0
with self.profiler.profile('on_batch_start'):
# callbacks
self.on_batch_start()
@@ -699,6 +702,15 @@
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

with self.profiler.profile('on_train_batch_start'):
# callbacks
self.on_train_batch_start()
# hooks
if self.is_function_implemented('on_train_batch_start'):
response = self.get_model().on_train_batch_start(batch)
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

splits = [batch]
if self.truncated_bptt_steps is not None:
model_ref = self.get_model()
@@ -785,6 +797,13 @@ def run_training_batch(self, batch, batch_idx):
if self.is_function_implemented('on_batch_end'):
self.get_model().on_batch_end()

with self.profiler.profile('on_train_batch_end'):
# callbacks
self.on_train_batch_end()
# model hooks
if self.is_function_implemented('on_train_batch_end'):
self.get_model().on_train_batch_end()

# collapse all metrics into one dict
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

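Note that the loop above still fires the deprecated `on_batch_start`/`on_batch_end` events immediately before their `on_train_batch_*` counterparts, so existing callbacks keep working during the deprecation window. A small illustrative callback (hypothetical `EventOrderCallback`) that records the resulting call order:

from pytorch_lightning.callbacks import Callback


class EventOrderCallback(Callback):
    """Collects batch-level events; the deprecated hooks arrive before the new ones."""

    def __init__(self):
        self.events = []

    def on_batch_start(self, trainer, pl_module):
        self.events.append('on_batch_start')        # deprecated, still called

    def on_train_batch_start(self, trainer, pl_module):
        self.events.append('on_train_batch_start')  # new hook, fired right after

    def on_batch_end(self, trainer, pl_module):
        self.events.append('on_batch_end')          # deprecated, still called

    def on_train_batch_end(self, trainer, pl_module):
        self.events.append('on_train_batch_end')    # new hook, fired right after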
16 changes: 16 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -28,6 +28,8 @@ def __init__(self):
self.on_epoch_end_called = False
self.on_batch_start_called = False
self.on_batch_end_called = False
self.on_train_batch_start_called = False
self.on_train_batch_end_called = False
self.on_validation_batch_start_called = False
self.on_validation_batch_end_called = False
self.on_test_batch_start_called = False
@@ -87,6 +89,14 @@ def on_batch_end(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_batch_end_called = True

def on_train_batch_start(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_train_batch_start_called = True

def on_train_batch_end(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_train_batch_end_called = True

def on_validation_batch_start(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_validation_batch_start_called = True
@@ -150,6 +160,8 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_epoch_start_called
assert not test_callback.on_batch_start_called
assert not test_callback.on_batch_end_called
assert not test_callback.on_train_batch_start_called
assert not test_callback.on_train_batch_end_called
assert not test_callback.on_validation_batch_start_called
assert not test_callback.on_validation_batch_end_called
assert not test_callback.on_test_batch_start_called
@@ -177,6 +189,8 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_epoch_start_called
assert not test_callback.on_batch_start_called
assert not test_callback.on_batch_end_called
assert not test_callback.on_train_batch_start_called
assert not test_callback.on_train_batch_end_called
assert not test_callback.on_validation_batch_start_called
assert not test_callback.on_validation_batch_end_called
assert not test_callback.on_test_batch_start_called
@@ -202,6 +216,8 @@ def on_test_end(self, trainer, pl_module):
assert test_callback.on_epoch_start_called
assert test_callback.on_batch_start_called
assert test_callback.on_batch_end_called
assert test_callback.on_train_batch_start_called
assert test_callback.on_train_batch_end_called
assert test_callback.on_validation_batch_start_called
assert test_callback.on_validation_batch_end_called
assert test_callback.on_train_start_called
8 changes: 4 additions & 4 deletions tests/callbacks/test_progress_bar.py
@@ -153,12 +153,12 @@ class CurrentProgressBar(ProgressBar):
val_batches_seen = 0
test_batches_seen = 0

def on_batch_start(self, trainer, pl_module):
super().on_batch_start(trainer, pl_module)
def on_train_batch_start(self, trainer, pl_module):
super().on_train_batch_start(trainer, pl_module)
assert self.train_batch_idx == trainer.batch_idx

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module)
def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module)
assert self.train_batch_idx == trainer.batch_idx + 1
if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
assert self.main_progress_bar.n == self.train_batch_idx
6 changes: 3 additions & 3 deletions tests/core/test_datamodules.py
@@ -50,17 +50,17 @@ def test_can_prepare_data(tmpdir):

# is_overridden prepare data = True
# has been called
# False
# False
dm._has_prepared_data = True
assert not trainer.can_prepare_data()

# has not been called
# True
# True
dm._has_prepared_data = False
assert trainer.can_prepare_data()

# is_overridden prepare data = False
# True
# True
dm.prepare_data = None
assert trainer.can_prepare_data()

2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
@@ -214,7 +214,7 @@ class RankZeroLoggerCheck(Callback):
# this class has to be defined outside the test function, otherwise we get pickle error
# due to the way ddp process is launched

def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module):
is_dummy = isinstance(trainer.logger.experiment, DummyExperiment)
if trainer.is_global_zero:
assert not is_dummy
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
@@ -377,7 +377,7 @@ def increment_on_load_checkpoint(self, _):
# Bind methods to keep track of epoch numbers, batch numbers it has seen
# as well as number of times it has called on_load_checkpoint()
model.on_epoch_end = types.MethodType(increment_epoch, model)
model.on_batch_start = types.MethodType(increment_batch, model)
model.on_train_batch_start = types.MethodType(increment_batch, model)
model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model)
return model

@@ -691,7 +691,7 @@ class InterruptCallback(Callback):
def __init__(self):
super().__init__()

def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

class HandleInterruptCallback(Callback):
2 changes: 1 addition & 1 deletion tests/utilities/test_dtype_device_mixin.py
@@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):

class DeviceAssertCallback(Callback):

def on_batch_start(self, trainer, model):
def on_train_batch_start(self, trainer, model):
rank = trainer.local_rank
assert isinstance(model, TopModule)
# index = None also means first device