Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
unpicklable_function = lambda: None


@pytest.fixture(scope="module")
def model_cases():
class TestHparamsNamespace:
learning_rate = 1
Expand Down Expand Up @@ -93,9 +92,9 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat
return model1, model2, model3, model4, model5, model6, model7


def test_lightning_hasattr(tmpdir, model_cases):
def test_lightning_hasattr(tmpdir):
"""Test that the lightning_hasattr works in all cases."""
model1, model2, model3, model4, model5, model6, model7 = models = model_cases
model1, model2, model3, model4, model5, model6, model7 = models = model_cases()
assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
Expand All @@ -112,9 +111,9 @@ def test_lightning_hasattr(tmpdir, model_cases):
assert not lightning_hasattr(m, "this_attr_not_exist")


def test_lightning_getattr(tmpdir, model_cases):
def test_lightning_getattr(tmpdir):
"""Test that the lightning_getattr works in all cases."""
models = model_cases
models = model_cases()
for i, m in enumerate(models[:3]):
value = lightning_getattr(m, "learning_rate")
assert value == i, "attribute not correctly extracted"
Expand All @@ -132,9 +131,9 @@ def test_lightning_getattr(tmpdir, model_cases):
lightning_getattr(m, "this_attr_not_exist")


def test_lightning_setattr(tmpdir, model_cases):
def test_lightning_setattr(tmpdir):
"""Test that the lightning_setattr works in all cases."""
models = model_cases
models = model_cases()
for m in models[:3]:
lightning_setattr(m, "learning_rate", 10)
assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"
Expand Down