@@ -106,7 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._current_hook_fx_name: Optional[str] = None
         self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
-        self._truncated_bptt_steps: Optional[int] = None
+        self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = dict()

     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -195,15 +195,15 @@ def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization

     @property
-    def truncated_bptt_steps(self) -> Optional[int]:
+    def truncated_bptt_steps(self) -> int:
         """
         truncated_bptt_steps: Truncated backpropagation through time performs backprop every k steps of a much longer sequence.
         If this is > 0, the training step is passed ``hiddens``.
         """
         return self._truncated_bptt_steps

     @truncated_bptt_steps.setter
-    def truncated_bptt_steps(self, truncated_bptt_steps: Optional[int]) -> None:
+    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
         self._truncated_bptt_steps = truncated_bptt_steps

     @property
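For context, a minimal sketch of how a user would enable truncated backpropagation through time after this change (the model and layer names are illustrative, not part of this PR). Setting the property to a positive integer enables TBPTT; the new default of 0 leaves it disabled:

    import torch
    import pytorch_lightning as pl

    class SequenceModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
            # A positive value enables TBPTT; the new default of 0 disables it.
            self.truncated_bptt_steps = 2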
@@ -538,9 +538,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
                 The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
             batch_idx (int): Integer displaying index of this batch
             optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-            hiddens(:class:`~torch.Tensor`): Passed in if either
+            hiddens(:class:`~torch.Tensor`): Passed in if
                 :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0


         Return:
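Under this contract, a sketch of the matching ``training_step`` (assuming the hypothetical ``SequenceModel`` above): when ``truncated_bptt_steps`` > 0, Lightning passes ``hiddens`` into the step and expects it back in the returned dict so it can be carried into the next split:

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` carries the recurrent state from the previous TBPTT split
        # (None on the first split, which torch.nn.LSTM accepts).
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = torch.nn.functional.mse_loss(out, y)
        # Returning `hiddens` lets Lightning feed it into the next split.
        return {"loss": loss, "hiddens": hiddens}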
@@ -1462,7 +1461,6 @@ def tbptt_split_batch(self, batch, split_size):
         Called in the training loop after
         :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
         if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
-        or :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0

         Each returned batch split is passed separately to :meth:`training_step`.
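For reference, a simplified sketch of what the default split does (a paraphrase under the assumption that the batch is a list of tensors shaped ``(batch, time, ...)``, not the exact implementation): each tensor is sliced along the time dimension into consecutive chunks of ``split_size``:

    def tbptt_split_batch(self, batch, split_size):
        # Slice every tensor along dim=1 (time) into chunks of `split_size`.
        splits = []
        for t in range(0, batch[0].size(1), split_size):
            splits.append([x[:, t:t + split_size] for x in batch])
        return splits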
@@ -1564,7 +1562,7 @@ def get_progress_bar_dict(self):
         if avg_training_loss is not None:
             tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

-        module_tbptt_enabled = self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0
+        module_tbptt_enabled = self.truncated_bptt_steps > 0
         trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
         if module_tbptt_enabled or trainer_tbptt_enabled:
             tqdm_dict["split_idx"] = self.trainer.split_idx