Skip to content

Commit da5ba50

Browse files
rnett authored and Borda committed
Unify attribute finding logic, fix not using dataloader when hparams present (#4559)
* Rebase onto master * indent fix * Remove duplicated logic * Use single return * Remove extra else * add `__contains__` to TestHparamsNamespace to fix tests * Fix lightning_setattr to set all valid attributes * update doc * better names * fix holder order preference * tests for new behavior * Comment about using the last holder Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Sean Naren <[email protected]> (cherry picked from commit eee3b1a)
1 parent cc38d4c commit da5ba50

File tree

2 files changed

+81
-56
lines changed

2 files changed

+81
-56
lines changed

pytorch_lightning/utilities/parsing.py

Lines changed: 46 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -196,76 +196,70 @@ def __repr__(self):
196196
return out
197197

198198

199-
def lightning_hasattr(model, attribute):
200-
""" Special hasattr for lightning. Checks for attribute in model namespace,
201-
the old hparams namespace/dict, and the datamodule. """
199+
def lightning_get_all_attr_holders(model, attribute):
200+
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
201+
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
202202
trainer = getattr(model, 'trainer', None)
203203

204-
attr = False
204+
holders = []
205+
205206
# Check if attribute in model
206207
if hasattr(model, attribute):
207-
attr = True
208+
holders.append(model)
209+
208210
# Check if attribute in model.hparams, either namespace or dict
209-
elif hasattr(model, 'hparams'):
210-
if isinstance(model.hparams, dict):
211-
attr = attribute in model.hparams
212-
else:
213-
attr = hasattr(model.hparams, attribute)
211+
if hasattr(model, 'hparams'):
212+
if attribute in model.hparams:
213+
holders.append(model.hparams)
214+
214215
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
215-
if not attr and trainer is not None:
216-
attr = hasattr(trainer.datamodule, attribute)
216+
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
217+
holders.append(trainer.datamodule)
217218

218-
return attr
219+
return holders
220+
221+
222+
def lightning_get_first_attr_holder(model, attribute):
223+
""" Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
224+
the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
225+
holders = lightning_get_all_attr_holders(model, attribute)
226+
if len(holders) == 0:
227+
return None
228+
# using the last holder to preserve backwards compatibility
229+
return holders[-1]
230+
231+
232+
def lightning_hasattr(model, attribute):
233+
""" Special hasattr for lightning. Checks for attribute in model namespace,
234+
the old hparams namespace/dict, and the datamodule. """
235+
return lightning_get_first_attr_holder(model, attribute) is not None
219236

220237

221238
def lightning_getattr(model, attribute):
222239
""" Special getattr for lightning. Checks for attribute in model namespace,
223240
the old hparams namespace/dict, and the datamodule. """
224-
trainer = getattr(model, 'trainer', None)
241+
holder = lightning_get_first_attr_holder(model, attribute)
242+
if holder is None:
243+
raise ValueError(f'{attribute} is neither stored in the model namespace'
244+
' nor the `hparams` namespace/dict, nor the datamodule.')
225245

226-
# Check if attribute in model
227-
if hasattr(model, attribute):
228-
attr = getattr(model, attribute)
229-
# Check if attribute in model.hparams, either namespace or dict
230-
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
231-
attr = model.hparams[attribute]
232-
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
233-
attr = getattr(model.hparams, attribute)
234-
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
235-
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
236-
attr = getattr(trainer.datamodule, attribute)
237-
else:
238-
raise ValueError(
239-
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
240-
' nor the datamodule.'
241-
)
242-
return attr
246+
if isinstance(holder, dict):
247+
return holder[attribute]
248+
return getattr(holder, attribute)
243249

244250

245251
def lightning_setattr(model, attribute, value):
246252
""" Special setattr for lightning. Checks for attribute in model namespace
247253
and the old hparams namespace/dict.
248254
Will also set the attribute on datamodule, if it exists.
249255
"""
250-
if not lightning_hasattr(model, attribute):
251-
raise ValueError(
252-
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
253-
' nor the datamodule.'
254-
)
255-
256-
trainer = getattr(model, 'trainer', None)
257-
258-
# Check if attribute in model
259-
if hasattr(model, attribute):
260-
setattr(model, attribute, value)
261-
262-
# Check if attribute in model.hparams, either namespace or dict
263-
elif hasattr(model, 'hparams'):
264-
if isinstance(model.hparams, dict):
265-
model.hparams[attribute] = value
256+
holders = lightning_get_all_attr_holders(model, attribute)
257+
if len(holders) == 0:
258+
raise ValueError(f'{attribute} is neither stored in the model namespace'
259+
' nor the `hparams` namespace/dict, nor the datamodule.')
260+
261+
for holder in holders:
262+
if isinstance(holder, dict):
263+
holder[attribute] = value
266264
else:
267-
setattr(model.hparams, attribute, value)
268-
269-
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
270-
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
271-
setattr(trainer.datamodule, attribute, value)
265+
setattr(holder, attribute, value)

tests/utilities/test_parsing.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def _get_test_cases():
1919
class TestHparamsNamespace:
2020
learning_rate = 1
2121

22+
def __contains__(self, item):
23+
return item == "learning_rate"
24+
2225
TestHparamsDict = {'learning_rate': 2}
2326

2427
class TestModel1: # test for namespace
@@ -52,12 +55,26 @@ class TestModel5: # test for datamodule
5255

5356
model5 = TestModel5()
5457

55-
return model1, model2, model3, model4, model5
58+
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
59+
trainer = Trainer
60+
hparams = TestHparamsDict
61+
62+
model6 = TestModel6()
63+
64+
TestHparamsDict2 = {'batch_size': 2}
65+
66+
class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
67+
trainer = Trainer
68+
hparams = TestHparamsDict2
69+
70+
model7 = TestModel7()
71+
72+
return model1, model2, model3, model4, model5, model6, model7
5673

5774

5875
def test_lightning_hasattr(tmpdir):
5976
""" Test that the lightning_hasattr works in all cases"""
60-
model1, model2, model3, model4, model5 = _get_test_cases()
77+
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
6178
assert lightning_hasattr(model1, 'learning_rate'), \
6279
'lightning_hasattr failed to find namespace variable'
6380
assert lightning_hasattr(model2, 'learning_rate'), \
@@ -68,6 +85,10 @@ def test_lightning_hasattr(tmpdir):
6885
'lightning_hasattr found variable when it should not'
6986
assert lightning_hasattr(model5, 'batch_size'), \
7087
'lightning_hasattr failed to find batch_size in datamodule'
88+
assert lightning_hasattr(model6, 'batch_size'), \
89+
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
90+
assert lightning_hasattr(model7, 'batch_size'), \
91+
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
7192

7293

7394
def test_lightning_getattr(tmpdir):
@@ -77,9 +98,13 @@ def test_lightning_getattr(tmpdir):
7798
value = lightning_getattr(m, 'learning_rate')
7899
assert value == i, 'attribute not correctly extracted'
79100

80-
model5 = models[4]
101+
model5, model6, model7 = models[4:]
81102
assert lightning_getattr(model5, 'batch_size') == 8, \
82103
'batch_size not correctly extracted'
104+
assert lightning_getattr(model6, 'batch_size') == 8, \
105+
'batch_size not correctly extracted'
106+
assert lightning_getattr(model7, 'batch_size') == 8, \
107+
'batch_size not correctly extracted'
83108

84109

85110
def test_lightning_setattr(tmpdir):
@@ -90,7 +115,13 @@ def test_lightning_setattr(tmpdir):
90115
assert lightning_getattr(m, 'learning_rate') == 10, \
91116
'attribute not correctly set'
92117

93-
model5 = models[4]
118+
model5, model6, model7 = models[4:]
94119
lightning_setattr(model5, 'batch_size', 128)
120+
lightning_setattr(model6, 'batch_size', 128)
121+
lightning_setattr(model7, 'batch_size', 128)
95122
assert lightning_getattr(model5, 'batch_size') == 128, \
96123
'batch_size not correctly set'
124+
assert lightning_getattr(model6, 'batch_size') == 128, \
125+
'batch_size not correctly set'
126+
assert lightning_getattr(model7, 'batch_size') == 128, \
127+
'batch_size not correctly set'

0 commit comments

Comments (0)