Skip to content

Commit b6c73b4

Browse files
maxjeblick, Borda, SkafteNicki, awaelchli, rohitgr7
committed
Find parameters which are specified in the LightningDataModule, only (#4347)
* search for attribute in datamodule if not found elsewhere
* add test for datamodule
* add lightning_getattr test for datamodule
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>

* Update CHANGELOG.md
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 30af4fc commit b6c73b4

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5656
- Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
5757

5858

59+
- Fixed `lightning_getattr`, `lightning_hasattr` not finding the correct attributes in datamodule ([#4347](https://github.com/PyTorchLightning/pytorch-lightning/pull/4347))
60+
61+
5962
## [1.0.5] - 2020-11-03
6063

6164
### Added

pytorch_lightning/utilities/parsing.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,9 @@ def __repr__(self):
177177
def lightning_hasattr(model, attribute):
178178
""" Special hasattr for lightning. Checks for attribute in model namespace,
179179
the old hparams namespace/dict, and the datamodule. """
180-
trainer = model.trainer
180+
trainer = getattr(model, 'trainer', None)
181181

182+
attr = False
182183
# Check if attribute in model
183184
if hasattr(model, attribute):
184185
attr = True
@@ -189,29 +190,25 @@ def lightning_hasattr(model, attribute):
189190
else:
190191
attr = hasattr(model.hparams, attribute)
191192
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
192-
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
193-
attr = getattr(trainer.datamodule, attribute)
194-
else:
195-
attr = False
193+
if not attr and trainer is not None:
194+
attr = hasattr(trainer.datamodule, attribute)
196195

197196
return attr
198197

199198

200199
def lightning_getattr(model, attribute):
201200
""" Special getattr for lightning. Checks for attribute in model namespace,
202201
the old hparams namespace/dict, and the datamodule. """
203-
trainer = model.trainer
202+
trainer = getattr(model, 'trainer', None)
204203

205204
# Check if attribute in model
206205
if hasattr(model, attribute):
207206
attr = getattr(model, attribute)
208207
# Check if attribute in model.hparams, either namespace or dict
209-
elif hasattr(model, 'hparams'):
210-
if isinstance(model.hparams, dict):
211-
attr = model.hparams[attribute]
212-
else:
213-
attr = getattr(model.hparams, attribute)
214-
208+
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
209+
attr = model.hparams[attribute]
210+
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
211+
attr = getattr(model.hparams, attribute)
215212
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
216213
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
217214
attr = getattr(trainer.datamodule, attribute)
@@ -230,7 +227,7 @@ def lightning_setattr(model, attribute, value):
230227
raise ValueError(f'{attribute} is neither stored in the model namespace'
231228
' nor the `hparams` namespace/dict, nor the datamodule.')
232229

233-
trainer = model.trainer
230+
trainer = getattr(model, 'trainer', None)
234231

235232
# Check if attribute in model
236233
if hasattr(model, attribute):

tests/utilities/parsing.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class TestHparamsNamespace:
2424

2525
class TestModel1: # test for namespace
2626
learning_rate = 0
27+
2728
model1 = TestModel1()
2829

2930
class TestModel2: # test for hparams namespace
@@ -41,12 +42,23 @@ class TestModel4: # fail case
4142

4243
model4 = TestModel4()
4344

44-
return model1, model2, model3, model4
45+
class DataModule:
46+
batch_size = 8
47+
48+
class Trainer:
49+
datamodule = DataModule
50+
51+
class TestModel5: # test for datamodule
52+
trainer = Trainer
53+
54+
model5 = TestModel5()
55+
56+
return model1, model2, model3, model4, model5
4557

4658

4759
def test_lightning_hasattr(tmpdir):
4860
""" Test that the lightning_hasattr works in all cases"""
49-
model1, model2, model3, model4 = _get_test_cases()
61+
model1, model2, model3, model4, model5 = _get_test_cases()
5062
assert lightning_hasattr(model1, 'learning_rate'), \
5163
'lightning_hasattr failed to find namespace variable'
5264
assert lightning_hasattr(model2, 'learning_rate'), \
@@ -55,6 +67,8 @@ def test_lightning_hasattr(tmpdir):
5567
'lightning_hasattr failed to find hparams dict variable'
5668
assert not lightning_hasattr(model4, 'learning_rate'), \
5769
'lightning_hasattr found variable when it should not'
70+
assert lightning_hasattr(model5, 'batch_size'), \
71+
'lightning_hasattr failed to find batch_size in datamodule'
5872

5973

6074
def test_lightning_getattr(tmpdir):
@@ -64,6 +78,10 @@ def test_lightning_getattr(tmpdir):
6478
value = lightning_getattr(m, 'learning_rate')
6579
assert value == i, 'attribute not correctly extracted'
6680

81+
model5 = models[4]
82+
assert lightning_getattr(model5, 'batch_size') == 8, \
83+
'batch_size not correctly extracted'
84+
6785

6886
def test_lightning_setattr(tmpdir):
6987
""" Test that the lightning_setattr works in all cases"""
@@ -72,3 +90,8 @@ def test_lightning_setattr(tmpdir):
7290
lightning_setattr(m, 'learning_rate', 10)
7391
assert lightning_getattr(m, 'learning_rate') == 10, \
7492
'attribute not correctly set'
93+
94+
model5 = models[4]
95+
lightning_setattr(model5, 'batch_size', 128)
96+
assert lightning_getattr(model5, 'batch_size') == 128, \
97+
'batch_size not correctly set'

0 commit comments

Comments (0)