diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index 4754c8a620383..2fbed03de84c2 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -35,7 +35,6 @@ unpicklable_function = lambda: None -@pytest.fixture(scope="module") def model_cases(): class TestHparamsNamespace: learning_rate = 1 @@ -93,9 +92,9 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat return model1, model2, model3, model4, model5, model6, model7 -def test_lightning_hasattr(tmpdir, model_cases): +def test_lightning_hasattr(tmpdir): """Test that the lightning_hasattr works in all cases.""" - model1, model2, model3, model4, model5, model6, model7 = models = model_cases + model1, model2, model3, model4, model5, model6, model7 = models = model_cases() assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable" assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable" assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable" @@ -112,9 +111,9 @@ def test_lightning_hasattr(tmpdir, model_cases): assert not lightning_hasattr(m, "this_attr_not_exist") -def test_lightning_getattr(tmpdir, model_cases): +def test_lightning_getattr(tmpdir): """Test that the lightning_getattr works in all cases.""" - models = model_cases + models = model_cases() for i, m in enumerate(models[:3]): value = lightning_getattr(m, "learning_rate") assert value == i, "attribute not correctly extracted" @@ -132,9 +131,9 @@ def test_lightning_getattr(tmpdir, model_cases): lightning_getattr(m, "this_attr_not_exist") -def test_lightning_setattr(tmpdir, model_cases): +def test_lightning_setattr(tmpdir): """Test that the lightning_setattr works in all cases.""" - models = model_cases + models = model_cases() for m in models[:3]: lightning_setattr(m, "learning_rate", 10) assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"