diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7e49a59c79f94..bb286e82759c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
 
+- Fixed `lightning_getattr`, `lightning_hasattr` not finding the correct attributes in datamodule ([#4347](https://github.com/PyTorchLightning/pytorch-lightning/pull/4347))
+
+
 ## [1.0.5] - 2020-11-03
 
 ### Added
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index c8230205752d4..348eec110c3a1 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -177,8 +177,9 @@ def __repr__(self):
 def lightning_hasattr(model, attribute):
     """ Special hasattr for lightning. Checks for attribute in model namespace,
         the old hparams namespace/dict, and the datamodule. """
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)
 
+    attr = False
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = True
@@ -189,10 +190,8 @@ def lightning_hasattr(model, attribute):
         else:
             attr = hasattr(model.hparams, attribute)
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
-        attr = getattr(trainer.datamodule, attribute)
-    else:
-        attr = False
+    if not attr and trainer is not None:
+        attr = hasattr(trainer.datamodule, attribute)
 
     return attr
 
@@ -200,18 +199,16 @@ def lightning_getattr(model, attribute):
     """ Special getattr for lightning. Checks for attribute in model namespace,
         the old hparams namespace/dict, and the datamodule.
     """
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)
 
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = getattr(model, attribute)
     # Check if attribute in model.hparams, either namespace or dict
-    elif hasattr(model, 'hparams'):
-        if isinstance(model.hparams, dict):
-            attr = model.hparams[attribute]
-        else:
-            attr = getattr(model.hparams, attribute)
-
+    elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
+        attr = model.hparams[attribute]
+    elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
+        attr = getattr(model.hparams, attribute)
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
     elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
         attr = getattr(trainer.datamodule, attribute)
@@ -230,7 +227,7 @@ def lightning_setattr(model, attribute, value):
         raise ValueError(f'{attribute} is neither stored in the model namespace'
                          ' nor the `hparams` namespace/dict, nor the datamodule.')
 
-    trainer = model.trainer
+    trainer = getattr(model, 'trainer', None)
 
     # Check if attribute in model
     if hasattr(model, attribute):
diff --git a/tests/utilities/parsing.py b/tests/utilities/parsing.py
index 13cfeaa64b01a..056590f1a6d35 100644
--- a/tests/utilities/parsing.py
+++ b/tests/utilities/parsing.py
@@ -24,6 +24,7 @@ class TestHparamsNamespace:
 
     class TestModel1:  # test for namespace
         learning_rate = 0
+
     model1 = TestModel1()
 
     class TestModel2:  # test for hparams namespace
@@ -41,12 +42,23 @@ class TestModel4:  # fail case
 
     model4 = TestModel4()
 
-    return model1, model2, model3, model4
+    class DataModule:
+        batch_size = 8
+
+    class Trainer:
+        datamodule = DataModule
+
+    class TestModel5:  # test for datamodule
+        trainer = Trainer
+
+    model5 = TestModel5()
+
+    return model1, model2, model3, model4, model5
 
 
 def test_lightning_hasattr(tmpdir):
     """ Test that the lightning_hasattr works in all cases"""
-    model1, model2, model3, model4 = _get_test_cases()
+    model1, model2, model3, model4, model5 = _get_test_cases()
     assert lightning_hasattr(model1, 'learning_rate'), \
         'lightning_hasattr failed to find namespace variable'
     assert lightning_hasattr(model2, 'learning_rate'), \
@@ -55,6 +67,8 @@ def test_lightning_hasattr(tmpdir):
         'lightning_hasattr failed to find hparams dict variable'
     assert not lightning_hasattr(model4, 'learning_rate'), \
         'lightning_hasattr found variable when it should not'
+    assert lightning_hasattr(model5, 'batch_size'), \
+        'lightning_hasattr failed to find batch_size in datamodule'
 
 
 def test_lightning_getattr(tmpdir):
@@ -64,6 +78,10 @@ def test_lightning_getattr(tmpdir):
         value = lightning_getattr(m, 'learning_rate')
         assert value == i, 'attribute not correctly extracted'
 
+    model5 = models[4]
+    assert lightning_getattr(model5, 'batch_size') == 8, \
+        'batch_size not correctly extracted'
+
 
 def test_lightning_setattr(tmpdir):
     """ Test that the lightning_setattr works in all cases"""
@@ -72,3 +90,8 @@ def test_lightning_setattr(tmpdir):
         lightning_setattr(m, 'learning_rate', 10)
         assert lightning_getattr(m, 'learning_rate') == 10, \
             'attribute not correctly set'
+
+    model5 = models[4]
+    lightning_setattr(model5, 'batch_size', 128)
+    assert lightning_getattr(model5, 'batch_size') == 128, \
+        'batch_size not correctly set'