
Commit 98670c8

ananthsub and Borda authored
Deprecate `truncated_bptt_steps` flag on Trainer in favor of same setting on the LightningModule (#7323)
* deprecate-tbptt-trainer
* Update CHANGELOG.md
* Update lightning.py
* test
* Update lightning.py
* Update training_loop.py
* Update training_loop.py
* Update lightning.py
* Update training_loop.py
* Update training_loop.py
* update docs
* Update accelerator.py
* Update accelerator.py
* more docs
* tweaks
* chlog
* comments

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 573a5a8 commit 98670c8

File tree

9 files changed: +143 additions, -28 deletions

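At a glance, the user-facing migration introduced by this commit looks like the following sketch. The class name ``SeqModel`` is illustrative only and not part of the diff:

    from pytorch_lightning import LightningModule, Trainer

    # Before (deprecated in v1.3, to be removed in v1.5): configure TBPTT on the Trainer
    trainer = Trainer(truncated_bptt_steps=2)

    # After: set the property on the LightningModule instead
    class SeqModel(LightningModule):
        def __init__(self):
            super().__init__()
            # splits each batch into time-chunks of length 2
            self.truncated_bptt_steps = 2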

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Changed

+- Changed `LightningModule.truncated_bptt_steps` to be property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
+
 - Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))

@@ -205,6 +208,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Deprecated

+- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
+
 - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))


docs/source/advanced/sequences.rst

Lines changed: 19 additions & 8 deletions
@@ -40,20 +40,31 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh

 Lightning can handle TBTT automatically via this flag.

-.. testcode::
+.. testcode:: python
+
+    from pytorch_lightning import LightningModule

-    # DEFAULT (single backwards pass per batch)
-    trainer = Trainer(truncated_bptt_steps=None)
+    class MyModel(LightningModule):

-    # (split batch into sequences of size 2)
-    trainer = Trainer(truncated_bptt_steps=2)
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
+
+        # Truncated back-propagation through time
+        def training_step(self, batch, batch_idx, hiddens):
+            # the training step must be updated to accept a ``hiddens`` argument
+            # hiddens are the hiddens from the previous truncated backprop step
+            out, hiddens = self.lstm(data, hiddens)
+            return {
+                "loss": ...,
+                "hiddens": hiddens
+            }

 .. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

-.. note:: Using this feature requires updating your LightningModule's
-   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
-
 ----------

 Iterable Datasets
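The documentation snippet added above is intentionally abbreviated (for instance, ``data`` is never defined). As a fuller, hedged illustration of the same pattern, here is a self-contained sketch; the model, random dataset, and hyperparameters are assumptions for demonstration, not part of the commit, and the hidden state is detached explicitly so each time-chunk backpropagates only through itself:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer

    class TBPTTModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
            self.head = torch.nn.Linear(16, 1)
            # activate TBPTT: each (batch, time, features) batch is split into time-chunks of length 2
            self.truncated_bptt_steps = 2

        def training_step(self, batch, batch_idx, hiddens):
            # hiddens is None for the first chunk, then the state returned by the previous chunk
            x, y = batch
            out, hiddens = self.rnn(x, hiddens)
            loss = torch.nn.functional.mse_loss(self.head(out), y)
            # detach so the next chunk does not backprop through this chunk's graph
            return {"loss": loss, "hiddens": hiddens.detach()}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    # toy data shaped (batch, time, features) with matching targets
    x = torch.randn(4, 10, 8)
    y = torch.randn(4, 10, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=2)
    Trainer(max_epochs=1, logger=False, checkpoint_callback=False).fit(TBPTTModel(), loader)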

docs/source/common/lightning_module.rst

Lines changed: 57 additions & 0 deletions
@@ -1005,6 +1005,63 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin

 --------------

+truncated_bptt_steps
+^^^^^^^^^^^^^^^^^^^^
+
+Truncated back prop breaks performs backprop every k steps of
+a much longer sequence.
+
+If this is enabled, your batches will automatically get truncated
+and the trainer will apply Truncated Backprop to it.
+
+(`Williams et al. "An efficient gradient-based algorithm for on-line training of
+recurrent network trajectories."
+<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
+
+`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_
+
+.. testcode:: python
+
+    from pytorch_lightning import LightningModule
+
+    class MyModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
+
+        # Truncated back-propagation through time
+        def training_step(self, batch, batch_idx, hiddens):
+            # the training step must be updated to accept a ``hiddens`` argument
+            # hiddens are the hiddens from the previous truncated backprop step
+            out, hiddens = self.lstm(data, hiddens)
+            return {
+                "loss": ...,
+                "hiddens": hiddens
+            }
+
+Lightning takes care to split your batch along the time-dimension.
+
+.. code-block:: python
+
+    # we use the second as the time dimension
+    # (batch, time, ...)
+    sub_batch = batch[0, 0:t, ...]
+
+To modify how the batch is split,
+override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
+
+.. testcode:: python
+
+    class LitMNIST(LightningModule):
+        def tbptt_split_batch(self, batch, split_size):
+            # do your own splitting on the batch
+            return splits
+
+--------------
+
 Hooks
 ^^^^^
 This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
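The ``tbptt_split_batch`` override in the docs above is pseudocode (``splits`` is undefined). A concrete sketch of such an override, assuming the batch is a (features, targets) pair shaped (batch, time, ...), might look like this; the class name is illustrative:

    from pytorch_lightning import LightningModule

    class LitSeqModel(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            x, y = batch
            splits = []
            for t in range(0, x.size(1), split_size):
                # keep features and targets aligned within each time-chunk
                splits.append([x[:, t:t + split_size], y[:, t:t + split_size]])
            return splits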

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 2 deletions
@@ -196,8 +196,7 @@ def training_step(
             - batch_idx (int): Integer displaying index of this batch
             - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
             - hiddens(:class:`~torch.Tensor`): Passed in if
-              :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
-
+              :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         """
         args[0] = self.to_device(args[0])

pytorch_lightning/core/lightning.py

Lines changed: 21 additions & 4 deletions
@@ -59,7 +59,7 @@ class LightningModule(
     Module,
 ):
     # Below is for property support of JIT in PyTorch 1.7
-    # since none of them is important when using JIT, we are going to ignore them.
+    # since none of these are important when using JIT, we are going to ignore them.
     __jit_unused_properties__ = [
         "datamodule",
         "example_input_array",
@@ -72,6 +72,8 @@ class LightningModule(
         "local_rank",
         "logger",
         "model_size",
+        "automatic_optimization",
+        "truncated_bptt_steps",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__

     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -104,6 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._current_hook_fx_name: Optional[str] = None
         self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
+        self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = dict()

     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -191,6 +194,18 @@ def automatic_optimization(self) -> bool:
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization

+    @property
+    def truncated_bptt_steps(self) -> int:
+        """
+        truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much a longer sequence.
+        If this is > 0, the training step is passed ``hiddens``.
+        """
+        return self._truncated_bptt_steps
+
+    @truncated_bptt_steps.setter
+    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
+        self._truncated_bptt_steps = truncated_bptt_steps
+
     @property
     def logger(self):
         """ Reference to the logger object in the Trainer. """
@@ -524,7 +539,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
             batch_idx (int): Integer displaying index of this batch
             optimizer_idx (int): When using multiple optimizers, this argument will also be present.
             hiddens(:class:`~torch.Tensor`): Passed in if
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+                :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

         Return:
             Any of.
@@ -1469,7 +1484,7 @@ def tbptt_split_batch(self, batch, split_size):
         Note:
             Called in the training loop after
             :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
-            if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
+            if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.

         """
@@ -1570,7 +1585,9 @@ def get_progress_bar_dict(self):
         if avg_training_loss is not None:
             tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

-        if self.trainer.truncated_bptt_steps is not None:
+        module_tbptt_enabled = self.truncated_bptt_steps > 0
+        trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
+        if module_tbptt_enabled or trainer_tbptt_enabled:
             tqdm_dict["split_idx"] = self.trainer.split_idx

         if self.trainer.logger is not None and self.trainer.logger.version is not None:
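A quick sketch of the new property's behavior (``TinyModule`` is a hypothetical minimal subclass, used only to illustrate the getter and setter added above):

    from pytorch_lightning import LightningModule

    class TinyModule(LightningModule):
        pass

    model = TinyModule()
    assert model.truncated_bptt_steps == 0   # disabled by default, backed by _truncated_bptt_steps
    model.truncated_bptt_steps = 4           # setter stores the value; training_step then receives ``hiddens``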

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 14 additions & 6 deletions
@@ -11,8 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List, Optional, Union
+
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.distributed import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,12 +26,12 @@ def __init__(self, trainer):

     def on_trainer_init(
         self,
-        gradient_clip_val,
-        gradient_clip_algorithm,
-        track_grad_norm,
-        accumulate_grad_batches,
-        truncated_bptt_steps,
-        terminate_on_nan,
+        gradient_clip_val: float,
+        gradient_clip_algorithm: str,
+        track_grad_norm: Union[int, float, str],
+        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        truncated_bptt_steps: Optional[int],
+        terminate_on_nan: bool,
     ):

         self.trainer.terminate_on_nan = terminate_on_nan
@@ -48,6 +51,11 @@ def on_trainer_init(
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)

+        if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
+            rank_zero_deprecation(
+                "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
+                " Set truncated_bptt_steps directly on the LightningModule instead."
+            )
         self.trainer.truncated_bptt_steps = truncated_bptt_steps

     def configure_accumulated_gradients(self, accumulate_grad_batches):

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -280,8 +280,8 @@ def __init__(

             track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

-            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
-                sequence.
+            truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5.
+                Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.

             val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                 use int to check every n steps (batches).

pytorch_lightning/trainer/training_loop.py

Lines changed: 18 additions & 6 deletions
@@ -14,7 +14,7 @@

 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -441,12 +441,13 @@ def _track_gradient_norm(self):
         grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
         return grad_norm_dict

-    def tbptt_split_batch(self, batch):
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
+        truncated_bptt_enabled = self._truncated_bptt_enabled()
+        if truncated_bptt_enabled:
             model_ref = self.trainer.lightning_module
             with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+                splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
         return splits

     def run_training_epoch(self):
@@ -626,7 +627,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

         # lightning module hook
-        splits = self.tbptt_split_batch(batch)
+        splits = self._tbptt_split_batch(batch)

         for split_idx, split_batch in enumerate(splits):

@@ -896,11 +897,22 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         )

         # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
+        if self._truncated_bptt_enabled():
             args.append(hiddens)

         return args

+    def _truncated_bptt_enabled(self) -> bool:
+        """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
+        return self._truncated_bptt_steps() > 0
+
+    def _truncated_bptt_steps(self) -> int:
+        lightning_module = self.trainer.lightning_module
+        # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
+        if lightning_module.truncated_bptt_steps > 0:
+            return lightning_module.truncated_bptt_steps
+        return self.trainer.truncated_bptt_steps or 0
+
     def save_loggers_on_train_batch_end(self):
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
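The precedence established by `_truncated_bptt_steps` above can be summarized with a small standalone sketch; `resolve_tbptt_steps` is a hypothetical helper that mirrors the loop's logic rather than a function in the library:

    from typing import Optional

    def resolve_tbptt_steps(module_steps: int, trainer_steps: Optional[int]) -> int:
        """Mirror of the precedence in TrainLoop._truncated_bptt_steps (illustrative only)."""
        # the LightningModule value wins; the deprecated Trainer flag is only a fallback
        if module_steps > 0:
            return module_steps
        return trainer_steps or 0

    assert resolve_tbptt_steps(2, 5) == 2      # module setting takes precedence
    assert resolve_tbptt_steps(0, 5) == 5      # deprecated Trainer flag still honored until v1.5
    assert resolve_tbptt_steps(0, None) == 0   # TBPTT disabled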

tests/deprecated_api/test_remove_1-5.py

Lines changed: 5 additions & 0 deletions
@@ -412,3 +412,8 @@ def test_v1_5_0_datamodule_setter():
     model.datamodule = datamodule
     with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
         _ = model.datamodule
+
+
+def test_v1_5_0_trainer_tbptt_steps(tmpdir):
+    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
+        _ = Trainer(truncated_bptt_steps=1)
