
Commit 0fbfbf9

Make tbptt imports Python 3.10 compatible (#13973)
* Make tbptt imports Python 3.10 compatible
* add chlog
1 parent 2919dcf commit 0fbfbf9
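
Background: Python 3.10 removed the deprecated aliases that `collections` re-exported from `collections.abc` (for example `collections.Sequence`), so the old spelling raises `AttributeError` at runtime. A minimal sketch of the behavior this commit guards against; the toy value is illustrative only and not part of the patch:

    # Sketch, assuming Python 3.10+ at runtime; not code from this commit.
    import collections.abc

    x = [1, 2, 3]

    # Works on every supported Python version:
    print(isinstance(x, collections.abc.Sequence))   # True

    # The alias removed in Python 3.10 (deprecated since 3.3):
    # isinstance(x, collections.Sequence)            # AttributeError on 3.10+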

File tree (2 files changed: +7, -4 lines)

  src/pytorch_lightning/CHANGELOG.md
  src/pytorch_lightning/core/module.py


src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -405,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))


+- Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) ([#13973](https://github.com/Lightning-AI/lightning/pull/13973))
+
+

 ## [1.6.5] - 2022-07-13

src/pytorch_lightning/core/module.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""

-import collections
+import collections.abc
 import inspect
 import logging
 import numbers
@@ -1712,7 +1712,7 @@ def tbptt_split_batch(self, batch, split_size):
                 for i, x in enumerate(batch):
                     if isinstance(x, torch.Tensor):
                         split_x = x[:, t:t + split_size]
-                    elif isinstance(x, collections.Sequence):
+                    elif isinstance(x, collections.abc.Sequence):
                         split_x = [None] * len(x)
                         for batch_idx in range(len(x)):
                             split_x[batch_idx] = x[batch_idx][t:t + split_size]
@@ -1726,7 +1726,7 @@ def tbptt_split_batch(self, batch, split_size):
         if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
         Each returned batch split is passed separately to :meth:`training_step`.
         """
-        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.Sequence))]
+        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
         assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

@@ -1736,7 +1736,7 @@ def tbptt_split_batch(self, batch, split_size):
             for i, x in enumerate(batch):
                 if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
-                elif isinstance(x, collections.Sequence):
+                elif isinstance(x, collections.abc.Sequence):
                     split_x = [None] * len(x)
                     for batch_idx in range(len(x)):
                         split_x[batch_idx] = x[batch_idx][t : t + split_size]
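
The same `collections.abc.Sequence` check, pulled out of `tbptt_split_batch` into a standalone sketch of the time-dimension split shown in the hunks above. The helper name `split_along_time` and the toy batch are illustrative assumptions, not code from this repository:

    # Standalone sketch of the splitting logic from the diff above.
    import collections.abc

    import torch


    def split_along_time(batch, t, split_size):
        """Slice each batch element on its time dimension."""
        out = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                split_x = x[:, t : t + split_size]
            elif isinstance(x, collections.abc.Sequence):  # valid on Python 3.10+
                split_x = [item[t : t + split_size] for item in x]
            else:
                split_x = x
            out.append(split_x)
        return out


    # A tensor element (batch=4, time=10, features=3) and a sequence element.
    batch = [torch.randn(4, 10, 3), [list(range(10)) for _ in range(4)]]
    print([type(s) for s in split_along_time(batch, t=0, split_size=2)])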
