From 4dcc0ad0976d31d013df7f397c9b251bb2dd32e2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:24:13 -0400 Subject: [PATCH 01/13] modified hook --- pytorch_lightning/callbacks/base.py | 16 ++++++++++++++++ pytorch_lightning/callbacks/progress.py | 6 +++--- pytorch_lightning/core/hooks.py | 21 +++++++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 10 ++++++++++ pytorch_lightning/trainer/lr_finder.py | 2 +- pytorch_lightning/trainer/training_loop.py | 19 +++++++++++++++++++ tests/callbacks/test_progress_bar.py | 2 +- 7 files changed, 71 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a9c6e1fb520cb..5807e191dcdca 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -46,6 +46,14 @@ def on_sanity_check_end(self, trainer, pl_module): """Called when the validation sanity check ends.""" pass + def on_train_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_train_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + def on_train_epoch_start(self, trainer, pl_module): """Called when the train epoch begins.""" pass @@ -82,6 +90,14 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass + def on_train_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_train_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + def on_validation_batch_start(self, trainer, pl_module): """Called when the validation batch begins.""" pass diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 0acdbcc7509ea..776fc3f958f8b 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -36,7 +36,7 @@ def __init__(self): def disable(self): self.enable = False - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() @@ -138,7 +138,7 @@ def on_train_start(self, trainer, pl_module): def on_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): self._train_batch_idx += 1 def on_validation_start(self, trainer, pl_module): @@ -318,7 +318,7 @@ def on_epoch_start(self, trainer, pl_module): self.main_progress_bar.reset(convert_inf(total_batches)) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: self.main_progress_bar.update(self.refresh_rate) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 8c6b726ac31d2..1218dcbe6760f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -77,6 +77,23 @@ def on_train_end(self) -> None: """ # do something at the end of training + def on_train_batch_start(self, batch: Any) -> None: + """ + Called in the training loop before anything happens for that batch. + + If you return -1 here, you will skip training for the rest of the current epoch. + + Args: + batch: The batched data as it is returned by the training DataLoader. + """ + # do something when the batch starts + + def on_train_batch_end(self) -> None: + """ + Called in the training loop after the batch. + """ + # do something when the batch end + def on_batch_start(self, batch: Any) -> None: """ Called in the training loop before anything happens for that batch. @@ -85,12 +102,16 @@ def on_batch_start(self, batch: Any) -> None: Args: batch: The batched data as it is returned by the training DataLoader. + + .. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_start` instead) """ # do something when the batch starts def on_batch_end(self) -> None: """ Called in the training loop after the batch. + + .. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_end` instead) """ # do something when the batch ends diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 89b5e712c9190..e703ea33b3742 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -111,6 +111,16 @@ def on_batch_end(self): for callback in self.callbacks: callback.on_batch_end(self, self.get_model()) + def on_train_batch_start(self): + """Called when the training batch begins.""" + for callback in self.callbacks: + callback.on_train_batch_start(self, self.get_model()) + + def on_train_batch_end(self): + """Called when the training batch ends.""" + for callback in self.callbacks: + callback.on_train_batch_end(self, self.get_model()) + def on_validation_batch_start(self): """Called when the validation batch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 3b2778d24071c..23ad702956e84 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -382,7 +382,7 @@ def on_batch_start(self, trainer, pl_module): self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): """ Called when the training batch ends, logs the calculated loss """ if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c8cb81ed090b1..993e8ccd53fd0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -263,6 +263,8 @@ class TrainerTrainLoopMixin(ABC): on_train_end: Callable on_batch_start: Callable on_batch_end: Callable + on_train_batch_start: Callable + on_train_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable @@ -690,6 +692,7 @@ def run_training_batch(self, batch, batch_idx): return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # Batch start events + # TODO: deprecate 1.0 with self.profiler.profile('on_batch_start'): # callbacks self.on_batch_start() @@ -699,6 +702,15 @@ def run_training_batch(self, batch, batch_idx): if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + with self.profiler.profile('on_train_batch_start'): + # callbacks + self.on_train_batch_start() + # hooks + if self.is_function_implemented('on_train_batch_start'): + response = self.get_model().on_train_batch_start(batch) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() @@ -785,6 +797,13 @@ def run_training_batch(self, batch, batch_idx): if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() + with self.profiler.profile('on_train_batch_end'): + # callbacks + self.on_train_batch_end() + # model hooks + if self.is_function_implemented('on_train_batch_end'): + self.get_model().on_train_batch_end() + # collapse all metrics into one dict batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 23743dc5dcb2c..c5c381ac46386 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -157,7 +157,7 @@ def on_batch_start(self, trainer, pl_module): super().on_batch_start(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: From f0126791c7ea272a0ed253254f1e8b44f401007e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:27:00 -0400 Subject: [PATCH 02/13] modified hook --- tests/callbacks/test_callbacks.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d10965524394b..83de82c71de67 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -28,6 +28,8 @@ def __init__(self): self.on_epoch_end_called = False self.on_batch_start_called = False self.on_batch_end_called = False + self.on_train_batch_start_called = False + self.on_train_batch_end_called = False self.on_validation_batch_start_called = False self.on_validation_batch_end_called = False self.on_test_batch_start_called = False @@ -87,6 +89,14 @@ def on_batch_end(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_batch_end_called = True + def on_train_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_train_batch_start_called = True + + def on_train_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_train_batch_end_called = True + def on_validation_batch_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_validation_batch_start_called = True @@ -150,6 +160,8 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_train_batch_start_called + assert not test_callback.on_train_batch_end_called assert not test_callback.on_validation_batch_start_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_test_batch_start_called @@ -177,6 +189,8 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_train_batch_start_called + assert not test_callback.on_train_batch_end_called assert not test_callback.on_validation_batch_start_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_test_batch_start_called @@ -202,6 +216,8 @@ def on_test_end(self, trainer, pl_module): assert test_callback.on_epoch_start_called assert test_callback.on_batch_start_called assert test_callback.on_batch_end_called + assert test_callback.on_train_batch_start_called + assert test_callback.on_train_batch_end_called assert test_callback.on_validation_batch_start_called assert test_callback.on_validation_batch_end_called assert test_callback.on_train_start_called From 375c037bab6fc1181b3bda62e260cd974df36404 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:28:31 -0400 Subject: [PATCH 03/13] modified hook --- pytorch_lightning/callbacks/base.py | 8 -------- pytorch_lightning/core/lightning.py | 2 +- tests/core/test_datamodules.py | 6 +++--- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 5807e191dcdca..82a0e6b0436a6 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -90,14 +90,6 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass - def on_train_batch_start(self, trainer, pl_module): - """Called when the validation batch begins.""" - pass - - def on_train_batch_end(self, trainer, pl_module): - """Called when the validation batch ends.""" - pass - def on_validation_batch_start(self, trainer, pl_module): """Called when the validation batch begins.""" pass diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 80081c0dd446f..79a97082796b5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg elif self.example_input_array is not None: input_data = self.example_input_array else: - raise ValueError(f'input_sample and example_input_array tensors are both missing.') + raise ValueError('input_sample and example_input_array tensors are both missing.') if 'example_outputs' not in kwargs: self.eval() diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ec66afb71ca22..305f7f3d69150 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -50,17 +50,17 @@ def test_can_prepare_data(tmpdir): # is_overridden prepare data = True # has been called - # False + # False dm._has_prepared_data = True assert not trainer.can_prepare_data() # has not been called - # True + # True dm._has_prepared_data = False assert trainer.can_prepare_data() # is_overridden prepare data = False - # True + # True dm.prepare_data = None assert trainer.can_prepare_data() From 51cfe130777daf96f78a5644f9bb7b1edb7303b5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:31:45 -0400 Subject: [PATCH 04/13] modified hook --- pytorch_lightning/callbacks/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 776fc3f958f8b..4ab990f74724e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -37,7 +37,7 @@ def disable(self): self.enable = False def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) # don't forget this :) + super().on_train_batch_end(trainer, pl_module) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() sys.stdout.write(f'{percent:.01f} percent complete \r') @@ -319,7 +319,7 @@ def on_epoch_start(self, trainer, pl_module): self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) + super().on_train_batch_end(trainer, pl_module) if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: self.main_progress_bar.update(self.refresh_rate) self.main_progress_bar.set_postfix(trainer.progress_bar_dict) From 6ab5bd4f5f6a815c67ea6ef1ddbb614dafc74a7e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:33:03 -0400 Subject: [PATCH 05/13] modified hook --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index c5c381ac46386..e123fe8241d8a 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -158,7 +158,7 @@ def on_batch_start(self, trainer, pl_module): assert self.train_batch_idx == trainer.batch_idx def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) + super().on_train_batch_end(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx From 5e23f7fabeea5c3672273482ec9e425ff34cf414 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:34:43 -0400 Subject: [PATCH 06/13] modified hook --- pytorch_lightning/callbacks/lr_logger.py | 2 +- tests/callbacks/test_progress_bar.py | 4 ++-- tests/loggers/test_all.py | 2 +- tests/trainer/test_trainer.py | 4 ++-- tests/utilities/test_dtype_device_mixin.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py index 87953d496b3ad..7ec73b8c88811 100755 --- a/pytorch_lightning/callbacks/lr_logger.py +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module): # Initialize for storing values self.lrs = {name: [] for name in names} - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): latest_stat = self._extract_lr(trainer, 'step') if trainer.logger and latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index e123fe8241d8a..779077c437585 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -153,8 +153,8 @@ class CurrentProgressBar(ProgressBar): val_batches_seen = 0 test_batches_seen = 0 - def on_batch_start(self, trainer, pl_module): - super().on_batch_start(trainer, pl_module) + def on_train_batch_start(self, trainer, pl_module): + super().on_train_batch_start(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx def on_train_batch_end(self, trainer, pl_module): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 3afa1dd11c56c..5bd81d7116948 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -214,7 +214,7 @@ class RankZeroLoggerCheck(Callback): # this class has to be defined outside the test function, otherwise we get pickle error # due to the way ddp process is launched - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): is_dummy = isinstance(trainer.logger.experiment, DummyExperiment) if trainer.is_global_zero: assert not is_dummy diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c7652ebecf3f9..3dbb7b7c079d6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -377,7 +377,7 @@ def increment_on_load_checkpoint(self, _): # Bind methods to keep track of epoch numbers, batch numbers it has seen # as well as number of times it has called on_load_checkpoint() model.on_epoch_end = types.MethodType(increment_epoch, model) - model.on_batch_start = types.MethodType(increment_batch, model) + model.on_train_batch_start = types.MethodType(increment_batch, model) model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model) return model @@ -691,7 +691,7 @@ class InterruptCallback(Callback): def __init__(self): super().__init__() - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): raise KeyboardInterrupt class HandleInterruptCallback(Callback): diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index f755cf5c634ed..08f808bda9ceb 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): class DeviceAssertCallback(Callback): - def on_batch_start(self, trainer, model): + def on_train_batch_start(self, trainer, model): rank = trainer.local_rank assert isinstance(model, TopModule) # index = None also means first device From ef2f932d8ef3171a7a176de44301268cf2b83be0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:24:13 -0400 Subject: [PATCH 07/13] modified hook --- pytorch_lightning/callbacks/base.py | 16 ++++++++++++++++ pytorch_lightning/callbacks/progress.py | 6 +++--- pytorch_lightning/core/hooks.py | 21 +++++++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 10 ++++++++++ pytorch_lightning/trainer/lr_finder.py | 2 +- pytorch_lightning/trainer/training_loop.py | 19 +++++++++++++++++++ tests/callbacks/test_progress_bar.py | 2 +- 7 files changed, 71 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a9c6e1fb520cb..5807e191dcdca 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -46,6 +46,14 @@ def on_sanity_check_end(self, trainer, pl_module): """Called when the validation sanity check ends.""" pass + def on_train_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_train_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + def on_train_epoch_start(self, trainer, pl_module): """Called when the train epoch begins.""" pass @@ -82,6 +90,14 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass + def on_train_batch_start(self, trainer, pl_module): + """Called when the validation batch begins.""" + pass + + def on_train_batch_end(self, trainer, pl_module): + """Called when the validation batch ends.""" + pass + def on_validation_batch_start(self, trainer, pl_module): """Called when the validation batch begins.""" pass diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 0acdbcc7509ea..776fc3f958f8b 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -36,7 +36,7 @@ def __init__(self): def disable(self): self.enable = False - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() @@ -138,7 +138,7 @@ def on_train_start(self, trainer, pl_module): def on_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): self._train_batch_idx += 1 def on_validation_start(self, trainer, pl_module): @@ -318,7 +318,7 @@ def on_epoch_start(self, trainer, pl_module): self.main_progress_bar.reset(convert_inf(total_batches)) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: self.main_progress_bar.update(self.refresh_rate) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 8c6b726ac31d2..1218dcbe6760f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -77,6 +77,23 @@ def on_train_end(self) -> None: """ # do something at the end of training + def on_train_batch_start(self, batch: Any) -> None: + """ + Called in the training loop before anything happens for that batch. + + If you return -1 here, you will skip training for the rest of the current epoch. + + Args: + batch: The batched data as it is returned by the training DataLoader. + """ + # do something when the batch starts + + def on_train_batch_end(self) -> None: + """ + Called in the training loop after the batch. + """ + # do something when the batch end + def on_batch_start(self, batch: Any) -> None: """ Called in the training loop before anything happens for that batch. @@ -85,12 +102,16 @@ def on_batch_start(self, batch: Any) -> None: Args: batch: The batched data as it is returned by the training DataLoader. + + .. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_start` instead) """ # do something when the batch starts def on_batch_end(self) -> None: """ Called in the training loop after the batch. + + .. warning:: Deprecated in 0.9.0 will remove 1.0.0 (use `on_train_batch_end` instead) """ # do something when the batch ends diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 89b5e712c9190..e703ea33b3742 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -111,6 +111,16 @@ def on_batch_end(self): for callback in self.callbacks: callback.on_batch_end(self, self.get_model()) + def on_train_batch_start(self): + """Called when the training batch begins.""" + for callback in self.callbacks: + callback.on_train_batch_start(self, self.get_model()) + + def on_train_batch_end(self): + """Called when the training batch ends.""" + for callback in self.callbacks: + callback.on_train_batch_end(self, self.get_model()) + def on_validation_batch_start(self): """Called when the validation batch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 3b2778d24071c..23ad702956e84 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -382,7 +382,7 @@ def on_batch_start(self, trainer, pl_module): self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): """ Called when the training batch ends, logs the calculated loss """ if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c8cb81ed090b1..993e8ccd53fd0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -263,6 +263,8 @@ class TrainerTrainLoopMixin(ABC): on_train_end: Callable on_batch_start: Callable on_batch_end: Callable + on_train_batch_start: Callable + on_train_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable @@ -690,6 +692,7 @@ def run_training_batch(self, batch, batch_idx): return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # Batch start events + # TODO: deprecate 1.0 with self.profiler.profile('on_batch_start'): # callbacks self.on_batch_start() @@ -699,6 +702,15 @@ def run_training_batch(self, batch, batch_idx): if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + with self.profiler.profile('on_train_batch_start'): + # callbacks + self.on_train_batch_start() + # hooks + if self.is_function_implemented('on_train_batch_start'): + response = self.get_model().on_train_batch_start(batch) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() @@ -785,6 +797,13 @@ def run_training_batch(self, batch, batch_idx): if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() + with self.profiler.profile('on_train_batch_end'): + # callbacks + self.on_train_batch_end() + # model hooks + if self.is_function_implemented('on_train_batch_end'): + self.get_model().on_train_batch_end() + # collapse all metrics into one dict batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 23743dc5dcb2c..c5c381ac46386 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -157,7 +157,7 @@ def on_batch_start(self, trainer, pl_module): super().on_batch_start(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: From 3b6a0c4fe947bb58550128e99245feba39cf1a6e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:27:00 -0400 Subject: [PATCH 08/13] modified hook --- tests/callbacks/test_callbacks.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d10965524394b..83de82c71de67 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -28,6 +28,8 @@ def __init__(self): self.on_epoch_end_called = False self.on_batch_start_called = False self.on_batch_end_called = False + self.on_train_batch_start_called = False + self.on_train_batch_end_called = False self.on_validation_batch_start_called = False self.on_validation_batch_end_called = False self.on_test_batch_start_called = False @@ -87,6 +89,14 @@ def on_batch_end(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_batch_end_called = True + def on_train_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_train_batch_start_called = True + + def on_train_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_train_batch_end_called = True + def on_validation_batch_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_validation_batch_start_called = True @@ -150,6 +160,8 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_train_batch_start_called + assert not test_callback.on_train_batch_end_called assert not test_callback.on_validation_batch_start_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_test_batch_start_called @@ -177,6 +189,8 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_train_batch_start_called + assert not test_callback.on_train_batch_end_called assert not test_callback.on_validation_batch_start_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_test_batch_start_called @@ -202,6 +216,8 @@ def on_test_end(self, trainer, pl_module): assert test_callback.on_epoch_start_called assert test_callback.on_batch_start_called assert test_callback.on_batch_end_called + assert test_callback.on_train_batch_start_called + assert test_callback.on_train_batch_end_called assert test_callback.on_validation_batch_start_called assert test_callback.on_validation_batch_end_called assert test_callback.on_train_start_called From 1d08334c6d57151483cd03ae4caeb949ff3be10f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:28:31 -0400 Subject: [PATCH 09/13] modified hook --- pytorch_lightning/callbacks/base.py | 8 -------- pytorch_lightning/core/lightning.py | 2 +- tests/core/test_datamodules.py | 6 +++--- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 5807e191dcdca..82a0e6b0436a6 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -90,14 +90,6 @@ def on_batch_start(self, trainer, pl_module): """Called when the training batch begins.""" pass - def on_train_batch_start(self, trainer, pl_module): - """Called when the validation batch begins.""" - pass - - def on_train_batch_end(self, trainer, pl_module): - """Called when the validation batch ends.""" - pass - def on_validation_batch_start(self, trainer, pl_module): """Called when the validation batch begins.""" pass diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d272c23fd9a65..f816726ddf1e1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg elif self.example_input_array is not None: input_data = self.example_input_array else: - raise ValueError(f'input_sample and example_input_array tensors are both missing.') + raise ValueError('input_sample and example_input_array tensors are both missing.') if 'example_outputs' not in kwargs: self.eval() diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ec66afb71ca22..305f7f3d69150 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -50,17 +50,17 @@ def test_can_prepare_data(tmpdir): # is_overridden prepare data = True # has been called - # False + # False dm._has_prepared_data = True assert not trainer.can_prepare_data() # has not been called - # True + # True dm._has_prepared_data = False assert trainer.can_prepare_data() # is_overridden prepare data = False - # True + # True dm.prepare_data = None assert trainer.can_prepare_data() From d8ef0a287b0f72bd47d14818b560ab99696d6278 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:31:45 -0400 Subject: [PATCH 10/13] modified hook --- pytorch_lightning/callbacks/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 776fc3f958f8b..4ab990f74724e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -37,7 +37,7 @@ def disable(self): self.enable = False def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) # don't forget this :) + super().on_train_batch_end(trainer, pl_module) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() sys.stdout.write(f'{percent:.01f} percent complete \r') @@ -319,7 +319,7 @@ def on_epoch_start(self, trainer, pl_module): self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) + super().on_train_batch_end(trainer, pl_module) if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: self.main_progress_bar.update(self.refresh_rate) self.main_progress_bar.set_postfix(trainer.progress_bar_dict) From 97a300f1c41fd7555ea93b6a0c0da603f0ab5e13 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:33:03 -0400 Subject: [PATCH 11/13] modified hook --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index c5c381ac46386..e123fe8241d8a 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -158,7 +158,7 @@ def on_batch_start(self, trainer, pl_module): assert self.train_batch_idx == trainer.batch_idx def on_train_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) + super().on_train_batch_end(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx From dec272c65d1ec26eaa51ae24987a4c9ee95849e0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 18:34:43 -0400 Subject: [PATCH 12/13] modified hook --- pytorch_lightning/callbacks/lr_logger.py | 2 +- tests/callbacks/test_progress_bar.py | 4 ++-- tests/loggers/test_all.py | 2 +- tests/trainer/test_trainer.py | 4 ++-- tests/utilities/test_dtype_device_mixin.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py index 87953d496b3ad..7ec73b8c88811 100755 --- a/pytorch_lightning/callbacks/lr_logger.py +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module): # Initialize for storing values self.lrs = {name: [] for name in names} - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): latest_stat = self._extract_lr(trainer, 'step') if trainer.logger and latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index e123fe8241d8a..779077c437585 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -153,8 +153,8 @@ class CurrentProgressBar(ProgressBar): val_batches_seen = 0 test_batches_seen = 0 - def on_batch_start(self, trainer, pl_module): - super().on_batch_start(trainer, pl_module) + def on_train_batch_start(self, trainer, pl_module): + super().on_train_batch_start(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx def on_train_batch_end(self, trainer, pl_module): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 3afa1dd11c56c..5bd81d7116948 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -214,7 +214,7 @@ class RankZeroLoggerCheck(Callback): # this class has to be defined outside the test function, otherwise we get pickle error # due to the way ddp process is launched - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): is_dummy = isinstance(trainer.logger.experiment, DummyExperiment) if trainer.is_global_zero: assert not is_dummy diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c7652ebecf3f9..3dbb7b7c079d6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -377,7 +377,7 @@ def increment_on_load_checkpoint(self, _): # Bind methods to keep track of epoch numbers, batch numbers it has seen # as well as number of times it has called on_load_checkpoint() model.on_epoch_end = types.MethodType(increment_epoch, model) - model.on_batch_start = types.MethodType(increment_batch, model) + model.on_train_batch_start = types.MethodType(increment_batch, model) model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model) return model @@ -691,7 +691,7 @@ class InterruptCallback(Callback): def __init__(self): super().__init__() - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module): raise KeyboardInterrupt class HandleInterruptCallback(Callback): diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index f755cf5c634ed..08f808bda9ceb 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): class DeviceAssertCallback(Callback): - def on_batch_start(self, trainer, model): + def on_train_batch_start(self, trainer, model): rank = trainer.local_rank assert isinstance(model, TopModule) # index = None also means first device From 4ca780a1d1e79cd83657fe0abd09b2d7fc652894 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 5 Aug 2020 19:31:44 -0400 Subject: [PATCH 13/13] modified hook --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index e703ea33b3742..7c62743455317 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -9,7 +9,7 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - get_model: Callable = ... + get_model: Callable def setup(self, stage: str): """Called in the beginning of fit and test"""