From bcfdbdff60a838fe566141dfc9e5f4817a2e3fc7 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 2 Aug 2022 03:15:57 +0200
Subject: [PATCH 1/2] Make tbptt imports Python 3.10 compatible

---
 src/pytorch_lightning/core/module.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index a66c7679b3ee0..b8cc1d91cde18 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""
 
-import collections
+import collections.abc
 import inspect
 import logging
 import numbers
@@ -1712,7 +1712,7 @@ def tbptt_split_batch(self, batch, split_size):
                     for i, x in enumerate(batch):
                         if isinstance(x, torch.Tensor):
                             split_x = x[:, t:t + split_size]
-                        elif isinstance(x, collections.Sequence):
+                        elif isinstance(x, collections.abc.Sequence):
                             split_x = [None] * len(x)
                             for batch_idx in range(len(x)):
                                 split_x[batch_idx] = x[batch_idx][t:t + split_size]
@@ -1726,7 +1726,7 @@
             if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
         """
-        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.Sequence))]
+        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
         assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
 
@@ -1736,7 +1736,7 @@
         for i, x in enumerate(batch):
             if isinstance(x, Tensor):
                 split_x = x[:, t : t + split_size]
-            elif isinstance(x, collections.Sequence):
+            elif isinstance(x, collections.abc.Sequence):
                 split_x = [None] * len(x)
                 for batch_idx in range(len(x)):
                     split_x[batch_idx] = x[batch_idx][t : t + split_size]

From e86b08187c257a6b57554719a2fe7e2c5e74f7e0 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 2 Aug 2022 03:21:35 +0200
Subject: [PATCH 2/2] add chlog

---
 src/pytorch_lightning/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 1516b74453842..cc7fe7b2c0d84 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))
 
+- Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) ([#13973](https://github.com/Lightning-AI/lightning/pull/13973))
+
+
 ## [1.6.5] - 2022-07-13
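
Note (editorial, not part of the patches above): the ABC aliases exposed
through the top-level ``collections`` namespace were deprecated in Python 3.3
and removed in Python 3.10, so ``collections.Sequence`` now raises
AttributeError, while ``collections.abc.Sequence`` keeps working on all
supported versions. A minimal standalone sketch of the check the patched code
performs:

    import collections.abc

    # Lists and tuples are Sequences; dicts are not.
    assert isinstance([1, 2, 3], collections.abc.Sequence)
    assert isinstance((4, 5), collections.abc.Sequence)
    assert not isinstance({"a": 1}, collections.abc.Sequence)

    # On Python 3.10+ the old alias is gone entirely:
    # isinstance([], collections.Sequence)  # raises AttributeError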
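
For intuition about the code being fixed: ``tbptt_split_batch`` slices every
tensor or sequence in the batch into chunks of at most ``split_size`` steps
along the time dimension. A minimal sketch of that slicing for the tensor case
(``split_time_dim`` is a hypothetical helper for illustration, not Lightning's
API):

    import torch

    def split_time_dim(x: torch.Tensor, split_size: int) -> list:
        # Cut a (batch, time, ...) tensor into chunks of at most split_size
        # steps along dim 1, mirroring x[:, t : t + split_size] in the patch.
        return [x[:, t : t + split_size] for t in range(0, x.size(1), split_size)]

    batch = torch.randn(4, 10, 8)  # batch=4, time=10, features=8
    splits = split_time_dim(batch, split_size=3)
    assert [s.size(1) for s in splits] == [3, 3, 3, 1]  # last split is shorter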