diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1516b74453842..cc7fe7b2c0d84 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897)) +- Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) ([#13973](https://github.com/Lightning-AI/lightning/pull/13973)) + + ## [1.6.5] - 2022-07-13 diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index a66c7679b3ee0..b8cc1d91cde18 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -13,7 +13,7 @@ # limitations under the License. """The LightningModule - an nn.Module with many additional features.""" -import collections +import collections.abc import inspect import logging import numbers @@ -1712,7 +1712,7 @@ def tbptt_split_batch(self, batch, split_size): for i, x in enumerate(batch): if isinstance(x, torch.Tensor): split_x = x[:, t:t + split_size] - elif isinstance(x, collections.Sequence): + elif isinstance(x, collections.abc.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): split_x[batch_idx] = x[batch_idx][t:t + split_size] @@ -1726,7 +1726,7 @@ def tbptt_split_batch(self, batch, split_size): if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0. Each returned batch split is passed separately to :meth:`training_step`. """ - time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.Sequence))] + time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))] assert len(time_dims) >= 1, "Unable to determine batch time dimension" assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous" @@ -1736,7 +1736,7 @@ def tbptt_split_batch(self, batch, split_size): for i, x in enumerate(batch): if isinstance(x, Tensor): split_x = x[:, t : t + split_size] - elif isinstance(x, collections.Sequence): + elif isinstance(x, collections.abc.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): split_x[batch_idx] = x[batch_idx][t : t + split_size]