|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import pickle |
15 | | -from argparse import ArgumentParser |
| 15 | +from argparse import ArgumentParser, Namespace |
| 16 | +from dataclasses import dataclass |
16 | 17 | from typing import Any, Dict |
17 | 18 | from unittest import mock |
18 | 19 | from unittest.mock import call, PropertyMock |
19 | 20 |
|
20 | 21 | import pytest |
21 | 22 | import torch |
| 23 | +from omegaconf import OmegaConf |
22 | 24 |
|
23 | 25 | from pytorch_lightning import LightningDataModule, Trainer |
24 | 26 | from pytorch_lightning.callbacks import ModelCheckpoint |
25 | 27 | from pytorch_lightning.utilities import AttributeDict |
| 28 | +from pytorch_lightning.utilities.exceptions import MisconfigurationException |
26 | 29 | from pytorch_lightning.utilities.model_helpers import is_overridden |
27 | 30 | from tests.helpers import BoringDataModule, BoringModel |
28 | 31 | from tests.helpers.datamodules import ClassifDataModule |
@@ -532,12 +535,101 @@ def test_dm_init_from_datasets_dataloaders(iterable): |
532 | 535 | ) |
533 | 536 |
|
534 | 537 |
|
535 | | -class DataModuleWithHparams(LightningDataModule): |
| 538 | +# all args |
| 539 | +class DataModuleWithHparams_0(LightningDataModule): |
536 | 540 | def __init__(self, arg0, arg1, kwarg0=None): |
537 | 541 | super().__init__() |
538 | 542 | self.save_hyperparameters() |
539 | 543 |
|
540 | 544 |
|
541 | | -def test_simple_hyperparameters_saving(): |
542 | | - data = DataModuleWithHparams(10, "foo", kwarg0="bar") |
| 545 | +# single arg |
| 546 | +class DataModuleWithHparams_1(LightningDataModule): |
| 547 | + def __init__(self, arg0, *args, **kwargs): |
| 548 | + super().__init__() |
| 549 | + self.save_hyperparameters(arg0) |
| 550 | + |
| 551 | + |
| 552 | +def test_hyperparameters_saving(): |
| 553 | + data = DataModuleWithHparams_0(10, "foo", kwarg0="bar") |
543 | 554 | assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"}) |
| 555 | + |
| 556 | + data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar") |
| 557 | + assert data.hparams == AttributeDict({"hello": "world"}) |
| 558 | + |
| 559 | + data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar") |
| 560 | + assert data.hparams == AttributeDict({"hello": "world"}) |
| 561 | + |
| 562 | + data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar") |
| 563 | + assert data.hparams == OmegaConf.create({"hello": "world"}) |
| 564 | + |
| 565 | + |
| 566 | +def test_define_as_dataclass(): |
| 567 | + # makes sure that no functionality is broken and the user can still manually make |
| 568 | + # super().__init__ call with parameters |
| 569 | + # also tests all the dataclass features that can be enabled without breaking anything |
| 570 | + @dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False) |
| 571 | + class BoringDataModule1(LightningDataModule): |
| 572 | + batch_size: int |
| 573 | + dims: int = 2 |
| 574 | + |
| 575 | + def __post_init__(self): |
| 576 | + super().__init__(dims=self.dims) |
| 577 | + |
| 578 | + # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e. |
| 579 | + # __repr__, __eq__, __lt__, __le__, etc. |
| 580 | + assert BoringDataModule1(batch_size=64).dims == 2 |
| 581 | + assert BoringDataModule1(batch_size=32) |
| 582 | + assert hasattr(BoringDataModule1, "__repr__") |
| 583 | + assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32) |
| 584 | + |
| 585 | + # asserts inherent calling of super().__init__ in case user doesn't make the call |
| 586 | + @dataclass |
| 587 | + class BoringDataModule2(LightningDataModule): |
| 588 | + batch_size: int |
| 589 | + |
| 590 | + # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e. |
| 591 | + # __init__, __repr__, __eq__, __lt__, __le__, etc. |
| 592 | + assert BoringDataModule2(batch_size=32) |
| 593 | + assert hasattr(BoringDataModule2, "__repr__") |
| 594 | + assert BoringDataModule2(batch_size=32).prepare_data() is None |
| 595 | + assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32) |
| 596 | + |
| 597 | + # checking for all the different multilevel inhertiance scenarios, for init call on LightningDataModule |
| 598 | + @dataclass |
| 599 | + class BoringModuleBase1(LightningDataModule): |
| 600 | + num_features: int |
| 601 | + |
| 602 | + class BoringModuleBase2(LightningDataModule): |
| 603 | + def __init__(self, num_features: int): |
| 604 | + self.num_features = num_features |
| 605 | + |
| 606 | + @dataclass |
| 607 | + class BoringModuleDerived1(BoringModuleBase1): |
| 608 | + ... |
| 609 | + |
| 610 | + class BoringModuleDerived2(BoringModuleBase1): |
| 611 | + def __init__(self): |
| 612 | + ... |
| 613 | + |
| 614 | + @dataclass |
| 615 | + class BoringModuleDerived3(BoringModuleBase2): |
| 616 | + ... |
| 617 | + |
| 618 | + class BoringModuleDerived4(BoringModuleBase2): |
| 619 | + def __init__(self): |
| 620 | + ... |
| 621 | + |
| 622 | + assert hasattr(BoringModuleDerived1(num_features=2), "_has_prepared_data") |
| 623 | + assert hasattr(BoringModuleDerived2(), "_has_prepared_data") |
| 624 | + assert hasattr(BoringModuleDerived3(), "_has_prepared_data") |
| 625 | + assert hasattr(BoringModuleDerived4(), "_has_prepared_data") |
| 626 | + |
| 627 | + |
| 628 | +def test_inconsistent_prepare_data_per_node(tmpdir): |
| 629 | + with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."): |
| 630 | + model = BoringModel() |
| 631 | + dm = BoringDataModule() |
| 632 | + trainer = Trainer(prepare_data_per_node=False) |
| 633 | + trainer.model = model |
| 634 | + trainer.datamodule = dm |
| 635 | + trainer.data_connector.prepare_data() |
0 commit comments