3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
 
 
+- Fixed `lightning_getattr`, `lightning_hasattr` not finding the correct attributes in datamodule ([#4347](https://github.com/PyTorchLightning/pytorch-lightning/pull/4347))
+
+
 ## [1.0.5] - 2020-11-03
 
 ### Added
23 changes: 10 additions & 13 deletions pytorch_lightning/utilities/parsing.py
@@ -177,8 +177,9 @@ def __repr__(self):
 def lightning_hasattr(model, attribute):
     """ Special hasattr for lightning. Checks for attribute in model namespace,
         the old hparams namespace/dict, and the datamodule. """
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)
 
+    attr = False
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = True
@@ -189,29 +190,25 @@ def lightning_hasattr(model, attribute):
         else:
             attr = hasattr(model.hparams, attribute)
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
-        attr = getattr(trainer.datamodule, attribute)
-    else:
-        attr = False
+    if not attr and trainer is not None:
+        attr = hasattr(trainer.datamodule, attribute)

> **Review comment (Collaborator):** What if the Trainer does not have `trainer.datamodule`?
> UPDATE: I found the assignment in the data connector, but it is missing in `Trainer.__init__`.
> cc: @tchaton
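A minimal sketch of the concern above, assuming a `Trainer` that never assigns `datamodule` in `__init__` (hypothetical class, not from this diff): the attribute access `trainer.datamodule` raises before `hasattr` can ever return `False`.

```python
class BareTrainer:
    """Hypothetical Trainer that never sets `datamodule`."""

trainer = BareTrainer()

# The merged line would fail here, because `trainer.datamodule` itself
# raises AttributeError before hasattr() is evaluated:
#     attr = hasattr(trainer.datamodule, 'batch_size')  # AttributeError

# A defensive variant (an assumption, not what this PR merges):
datamodule = getattr(trainer, 'datamodule', None)
attr = datamodule is not None and hasattr(datamodule, 'batch_size')
print(attr)  # False, instead of crashing
```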

     return attr


 def lightning_getattr(model, attribute):
     """ Special getattr for lightning. Checks for attribute in model namespace,
         the old hparams namespace/dict, and the datamodule. """
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)

     # Check if attribute in model
     if hasattr(model, attribute):
         attr = getattr(model, attribute)
     # Check if attribute in model.hparams, either namespace or dict
-    elif hasattr(model, 'hparams'):
-        if isinstance(model.hparams, dict):
-            attr = model.hparams[attribute]
-        else:
-            attr = getattr(model.hparams, attribute)
-
+    elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
+        attr = model.hparams[attribute]
+    elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
+        attr = getattr(model.hparams, attribute)
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
     elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
         attr = getattr(trainer.datamodule, attribute)
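The reason the single `hparams` branch had to be split: with the old code, dict-style `hparams` that lacked the key raised `KeyError` before the datamodule was ever consulted. A minimal sketch with hypothetical stand-in classes, not from the diff:

```python
class DataModule:
    batch_size = 8

class Trainer:
    datamodule = DataModule()

class Model:
    hparams = {'learning_rate': 1e-3}  # dict hparams without 'batch_size'
    trainer = Trainer()

model = Model()
# Old: hasattr(model, 'hparams') is True and hparams is a dict, so the
#      lookup did model.hparams['batch_size'] -> KeyError.
# New: 'batch_size' fails both hparams guards, so the lookup falls
#      through to trainer.datamodule and returns 8.
```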
@@ -230,7 +227,7 @@ def lightning_setattr(model, attribute, value):
         raise ValueError(f'{attribute} is neither stored in the model namespace'
                          ' nor the `hparams` namespace/dict, nor the datamodule.')
 
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)
 
     # Check if attribute in model
     if hasattr(model, attribute):
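What the repeated `trainer = getattr(model, 'trainer', None)` change buys across all three helpers, sketched under the assumption that the rest of each function is unchanged: a plain object with no `trainer` attribute used to crash on the helper's first line, and now degrades to the model/hparams checks.

```python
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr

class PlainModel:
    """Hypothetical model-like object with no trainer attached."""
    learning_rate = 0.01

model = PlainModel()
# Before: `trainer = model.trainer` raised AttributeError immediately.
# After:  trainer is None and the lookup still resolves via the model namespace.
assert lightning_hasattr(model, 'learning_rate')
assert lightning_getattr(model, 'learning_rate') == 0.01
```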
27 changes: 25 additions & 2 deletions tests/utilities/parsing.py
@@ -24,6 +24,7 @@ class TestHparamsNamespace:

     class TestModel1:  # test for namespace
         learning_rate = 0
+
     model1 = TestModel1()
 
     class TestModel2:  # test for hparams namespace
@@ -41,12 +42,23 @@ class TestModel4:  # fail case

     model4 = TestModel4()
 
-    return model1, model2, model3, model4
+    class DataModule:
+        batch_size = 8
+
+    class Trainer:
+        datamodule = DataModule
+
+    class TestModel5:  # test for datamodule
+        trainer = Trainer
+
+    model5 = TestModel5()
+
+    return model1, model2, model3, model4, model5


 def test_lightning_hasattr(tmpdir):
     """ Test that the lightning_hasattr works in all cases"""
-    model1, model2, model3, model4 = _get_test_cases()
+    model1, model2, model3, model4, model5 = _get_test_cases()
     assert lightning_hasattr(model1, 'learning_rate'), \
         'lightning_hasattr failed to find namespace variable'
     assert lightning_hasattr(model2, 'learning_rate'), \
@@ -55,6 +67,8 @@ def test_lightning_hasattr(tmpdir):
         'lightning_hasattr failed to find hparams dict variable'
     assert not lightning_hasattr(model4, 'learning_rate'), \
         'lightning_hasattr found variable when it should not'
+    assert lightning_hasattr(model5, 'batch_size'), \
+        'lightning_hasattr failed to find batch_size in datamodule'


 def test_lightning_getattr(tmpdir):
@@ -64,6 +78,10 @@ def test_lightning_getattr(tmpdir):
         value = lightning_getattr(m, 'learning_rate')
         assert value == i, 'attribute not correctly extracted'

+    model5 = models[4]
+    assert lightning_getattr(model5, 'batch_size') == 8, \
+        'batch_size not correctly extracted'
+

 def test_lightning_setattr(tmpdir):
     """ Test that the lightning_setattr works in all cases"""
@@ -72,3 +90,8 @@ def test_lightning_setattr(tmpdir):
         lightning_setattr(m, 'learning_rate', 10)
         assert lightning_getattr(m, 'learning_rate') == 10, \
             'attribute not correctly set'
+
+    model5 = models[4]
+    lightning_setattr(model5, 'batch_size', 128)
+    assert lightning_getattr(model5, 'batch_size') == 128, \
+        'batch_size not correctly set'
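Putting the three helpers together on a datamodule-backed attribute, mirroring what the new `model5` assertions exercise (hypothetical stand-in classes; assumes this patch is applied):

```python
from pytorch_lightning.utilities.parsing import (
    lightning_getattr, lightning_hasattr, lightning_setattr)

class DataModule:
    batch_size = 8

class Trainer:
    datamodule = DataModule()

class Model:
    trainer = Trainer()  # the datamodule is reached through the Trainer

model = Model()
assert lightning_hasattr(model, 'batch_size')        # found via the datamodule
assert lightning_getattr(model, 'batch_size') == 8
lightning_setattr(model, 'batch_size', 128)          # written to the datamodule
assert lightning_getattr(model, 'batch_size') == 128
```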