Skip to content

Commit 24f0124

Browse files
authored
Deprecate DataModule properties: train_transforms, val_transforms, test_transforms, dims, and size (#8851)
* Deprecate DataModule properties: train_transforms, val_transforms, test_transforms, dims, and size
1 parent b47e3ab commit 24f0124

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8181
- Deprecated `LightningModule.model_size` ([#8343](https://github.com/PyTorchLightning/pytorch-lightning/pull/8343))
8282

8383

84-
-
84+
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
8585

8686

8787
-

pytorch_lightning/core/datamodule.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ def teardown(self):
7070

7171
def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
7272
super().__init__()
73+
if train_transforms is not None:
74+
rank_zero_deprecation(
75+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
76+
)
77+
if val_transforms is not None:
78+
rank_zero_deprecation(
79+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
80+
)
81+
if test_transforms is not None:
82+
rank_zero_deprecation(
83+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
84+
)
85+
if dims is not None:
86+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
7387
self._train_transforms = train_transforms
7488
self._val_transforms = val_transforms
7589
self._test_transforms = test_transforms
@@ -95,55 +109,94 @@ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=N
95109
def train_transforms(self):
96110
"""
97111
Optional transforms (or collection of transforms) you can apply to train dataset
112+
113+
.. deprecated:: v1.5
114+
Will be removed in v1.7.0.
98115
"""
116+
117+
rank_zero_deprecation(
118+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
119+
)
99120
return self._train_transforms
100121

101122
@train_transforms.setter
102123
def train_transforms(self, t):
124+
rank_zero_deprecation(
125+
"DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
126+
)
103127
self._train_transforms = t
104128

105129
@property
106130
def val_transforms(self):
107131
"""
108132
Optional transforms (or collection of transforms) you can apply to validation dataset
133+
134+
.. deprecated:: v1.5
135+
Will be removed in v1.7.0.
109136
"""
137+
138+
rank_zero_deprecation(
139+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
140+
)
110141
return self._val_transforms
111142

112143
@val_transforms.setter
113144
def val_transforms(self, t):
145+
rank_zero_deprecation(
146+
"DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
147+
)
114148
self._val_transforms = t
115149

116150
@property
117151
def test_transforms(self):
118152
"""
119153
Optional transforms (or collection of transforms) you can apply to test dataset
154+
155+
.. deprecated:: v1.5
156+
Will be removed in v1.7.0.
120157
"""
158+
159+
rank_zero_deprecation(
160+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
161+
)
121162
return self._test_transforms
122163

123164
@test_transforms.setter
124165
def test_transforms(self, t):
166+
rank_zero_deprecation(
167+
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
168+
)
125169
self._test_transforms = t
126170

127171
@property
128172
def dims(self):
129173
"""
130174
A tuple describing the shape of your data. Extra functionality exposed in ``size``.
175+
176+
.. deprecated:: v1.5
177+
Will be removed in v1.7.0.
131178
"""
179+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
132180
return self._dims
133181

134182
@dims.setter
135183
def dims(self, d):
184+
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
136185
self._dims = d
137186

138187
def size(self, dim=None) -> Union[Tuple, int]:
139188
"""
140189
Return the dimension of each input either as a tuple or list of tuples. You can index this
141190
just as you would with a torch tensor.
191+
192+
.. deprecated:: v1.5
193+
Will be removed in v1.7.0.
142194
"""
143195

144196
if dim is not None:
145197
return self.dims[dim]
146198

199+
rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.")
147200
return self.dims
148201

149202
@property

tests/deprecated_api/test_remove_1-7.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
import pytest
1717

18+
from pytorch_lightning import LightningDataModule
1819
from tests.deprecated_api import _soft_unimport_module
1920
from tests.helpers import BoringModel
21+
from tests.helpers.datamodules import MNISTDataModule
2022

2123

2224
def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
@@ -46,3 +48,35 @@ def test_v1_7_0_deprecated_model_size():
4648
match="LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7"
4749
):
4850
_ = model.model_size
51+
52+
53+
def test_v1_7_0_datamodule_transform_properties(tmpdir):
54+
dm = MNISTDataModule()
55+
with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
56+
dm.train_transforms = "a"
57+
with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
58+
dm.val_transforms = "b"
59+
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
60+
dm.test_transforms = "c"
61+
with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
62+
_ = LightningDataModule(train_transforms="a")
63+
with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
64+
_ = LightningDataModule(val_transforms="b")
65+
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
66+
_ = LightningDataModule(test_transforms="c")
67+
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
68+
_ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
69+
70+
71+
def test_v1_7_0_datamodule_size_property(tmpdir):
72+
dm = MNISTDataModule()
73+
with pytest.deprecated_call(match=r"DataModule property `size` was deprecated in v1.5"):
74+
dm.size()
75+
76+
77+
def test_v1_7_0_datamodule_dims_property(tmpdir):
78+
dm = MNISTDataModule()
79+
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
80+
_ = dm.dims
81+
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
82+
_ = LightningDataModule(dims=(1, 1, 1))

0 commit comments

Comments
 (0)