
Commit 020adeb

update docs

1 parent 074818d commit 020adeb

5 files changed: +26 −19 lines

docs/source/advanced/sequences.rst

Lines changed: 13 additions & 5 deletions

@@ -40,13 +40,21 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh
 
 Lightning can handle TBTT automatically via this flag.
 
-.. testcode::
+.. testcode:: python
+
+    from pytorch_lightning import LightningModule
+
+    class MyModel(LightningModule):
 
-    # DEFAULT (single backwards pass per batch)
-    trainer = Trainer(truncated_bptt_steps=None)
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
 
-    # (split batch into sequences of size 2)
-    trainer = Trainer(truncated_bptt_steps=2)
+        def training_step(batch, batch_idx, hiddens):
+            # The training_step will be passed a `hiddens` argument for the split batch
+            ...
 
 .. note:: If you need to modify how the batch is split,
     override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
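The note above points readers at `tbptt_split_batch`. As a companion to this doc change, a minimal sketch of such an override follows; the hook and its `(batch, split_size)` signature come from the diff, while the assumption that the batch is a single tensor laid out as `(batch, time, features)` is purely illustrative:

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):

        def tbptt_split_batch(self, batch, split_size):
            # Assumption: `batch` is one tensor shaped (batch, time, features).
            # Slice along the time dimension into chunks of `split_size`;
            # each chunk is then passed to `training_step` as its own split.
            return [batch[:, t:t + split_size] for t in range(0, batch.size(1), split_size)]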

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion

@@ -196,7 +196,7 @@ def training_step(
     - batch_idx (int): Integer displaying index of this batch
     - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
     - hiddens(:class:`~torch.Tensor`): Passed in if
-      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+      :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
 
 """
 args[0] = self.to_device(args[0])
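Since this docstring change ties `hiddens` to the module-level flag, a hedged sketch of the corresponding `training_step` contract may help. The model below is hypothetical, and it follows the Lightning convention that a TBTT step returns its loss together with a detached hidden state under the "hiddens" key:

    import torch.nn as nn
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule


    class SeqModel(LightningModule):  # hypothetical example module

        def __init__(self):
            super().__init__()
            self.truncated_bptt_steps = 2  # activates TBTT, as in the docs diff
            self.rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
            self.head = nn.Linear(16, 1)

        def training_step(self, batch, batch_idx, hiddens):
            # `hiddens` is None on the first split of a batch, then the state
            # returned by the previous split.
            x, y = batch
            out, hiddens = self.rnn(x, hiddens)
            loss = F.mse_loss(self.head(out), y)
            # Detach so the graph does not grow beyond the truncation window.
            return {"loss": loss, "hiddens": tuple(h.detach() for h in hiddens)}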

pytorch_lightning/core/lightning.py

Lines changed: 5 additions & 7 deletions

@@ -106,7 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._current_hook_fx_name: Optional[str] = None
         self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
-        self._truncated_bptt_steps: Optional[int] = None
+        self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = dict()
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -195,15 +195,15 @@ def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
 
     @property
-    def truncated_bptt_steps(self) -> Optional[int]:
+    def truncated_bptt_steps(self) -> int:
         """
         truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much a longer sequence.
        If this is > 0, the training step is passed ``hiddens``.
         """
         return self._truncated_bptt_steps
 
     @truncated_bptt_steps.setter
-    def truncated_bptt_steps(self, truncated_bptt_steps: Optional[int]) -> None:
+    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
         self._truncated_bptt_steps = truncated_bptt_steps
 
     @property
@@ -538,9 +538,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
             The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
         batch_idx (int): Integer displaying index of this batch
         optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-        hiddens(:class:`~torch.Tensor`): Passed in if either
+        hiddens(:class:`~torch.Tensor`): Passed in if
             :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
-            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0
 
 
     Return:
@@ -1462,7 +1461,6 @@ def tbptt_split_batch(self, batch, split_size):
         Called in the training loop after
         :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
         if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
-        or :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0
 
         Each returned batch split is passed separately to :meth:`training_step`.
@@ -1564,7 +1562,7 @@ def get_progress_bar_dict(self):
         if avg_training_loss is not None:
             tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
 
-        module_tbptt_enabled = self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0
+        module_tbptt_enabled = self.truncated_bptt_steps > 0
         trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
         if module_tbptt_enabled or trainer_tbptt_enabled:
             tqdm_dict["split_idx"] = self.trainer.split_idx

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 3 deletions

@@ -279,9 +279,8 @@ def __init__(
 
             track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
 
-            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
-                sequence. This argument has been moved to LightningModule. It is deprecated here in v1.3 and
-                will be removed in v1.5.
+            truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5.
+                Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.
 
             val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                 use int to check every n steps (batches).
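In migration terms, the new wording means moving the flag from the Trainer constructor onto the module, roughly:

    from pytorch_lightning import LightningModule, Trainer

    # Deprecated in v1.3, to be removed in v1.5:
    trainer = Trainer(truncated_bptt_steps=2)

    # Preferred replacement:
    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.truncated_bptt_steps = 2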

pytorch_lightning/trainer/training_loop.py

Lines changed: 5 additions & 3 deletions

@@ -885,16 +885,18 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
     def _truncated_bptt_enabled(self) -> bool:
         """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
         module = self.trainer.lightning_module
-        module_tbptt_enabled = module.truncated_bptt_steps is not None and module.truncated_bptt_steps > 0
+        module_tbptt_enabled = module.truncated_bptt_steps > 0
+        if module_tbptt_enabled:
+            return True
 
         trainer = self.trainer
         trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0
-        return module_tbptt_enabled or trainer_tbptt_enabled
+        return trainer_tbptt_enabled
 
     def _truncated_bptt_steps(self) -> Optional[int]:
         lightning_module = self.trainer.lightning_module
         # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
-        if lightning_module.truncated_bptt_steps is not None and lightning_module.truncated_bptt_steps > 0:
+        if lightning_module.truncated_bptt_steps > 0:
             return lightning_module.truncated_bptt_steps
         return self.trainer.truncated_bptt_steps
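Both helpers above encode one precedence rule. A hypothetical standalone mirror of it, for readers skimming the diff (the function name and signature are illustrative, not part of the codebase):

    from typing import Optional


    def resolve_tbptt_steps(module_steps: int, trainer_steps: Optional[int]) -> Optional[int]:
        # The LightningModule value wins whenever it is set (> 0);
        # the deprecated Trainer flag is consulted only as a fallback.
        if module_steps > 0:
            return module_steps
        return trainer_steps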
