From 3615f1a51149d81f1f3aa0156f8206f29157072c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 23 Feb 2021 00:21:45 +0530 Subject: [PATCH 1/8] remove warning --- pytorch_lightning/overrides/base.py | 2 -- tests/overrides/test_data_parallel.py | 49 --------------------------- 2 files changed, 51 deletions(-) diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 170cdc4600bb4..34eabf42c1a07 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -57,11 +57,9 @@ def forward(self, *inputs, **kwargs): elif trainer and trainer.testing: output = self.module.test_step(*inputs, **kwargs) - warn_if_output_is_none(output, "test_step") elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) - warn_if_output_is_none(output, "validation_step") elif trainer and trainer.predicting: output = self.module.predict(*inputs, **kwargs) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 5008ec798f7ec..e8aaefc2f973c 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -44,55 +44,6 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage): getattr(pl_module, step).assert_called_with(batch, batch_idx) -@pytest.mark.parametrize("wrapper_class", [ - LightningParallelModule, - LightningDistributedModule, -]) -@pytest.mark.parametrize("stage", [ - ("training", "training_step"), - ("testing", "test_step"), - ("validating", "validation_step"), -]) -def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage): - """ Test that the LightningWrapper module warns about forgotten return statement. """ - warning_cache.clear() - pl_module = MagicMock() - - prop, step = stage - pl_module.trainer.sanity_checking = False - for p in ("training", "testing", "validating", "predicting"): - setattr(pl_module.trainer, p, p == prop) - - wrapped_module = wrapper_class(pl_module) - - getattr(pl_module, step).return_value = None - - with pytest.warns(UserWarning, match=f"Your {step} returned None"): - wrapped_module() - - -@pytest.mark.parametrize("wrapper_class", [ - LightningParallelModule, - LightningDistributedModule, -]) -def test_lightning_wrapper_module_no_warn(wrapper_class): - warning_cache.clear() - pl_module = MagicMock() - - pl_module.trainer.sanity_checking = False - pl_module.trainer.training = False - pl_module.trainer.testing = False - pl_module.trainer.validating = False - pl_module.trainer.predicting = False - - wrapped_module = wrapper_class(pl_module) - - with pytest.warns(None) as record: - wrapped_module() - pl_module.assert_called() - assert not record - - @pytest.mark.parametrize( "inp,expected", [ [torch.tensor(1.0), torch.tensor([1.0])], From 275ac56c4dcc7b07a9c49e9849ff0dea296ecd8c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 23 Feb 2021 00:32:08 +0530 Subject: [PATCH 2/8] auto_opt --- pytorch_lightning/overrides/base.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 34eabf42c1a07..9bdd1fcb1dc6e 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -51,20 +51,14 @@ def forward(self, *inputs, **kwargs): # it is done manually in ``LightningModule.manual_backward`` # `require_backward_grad_sync` will be reset in the # ddp_plugin ``post_training_step`` hook - if not self.module.automatic_optimization: + if self.module.automatic_optimization: trainer.model.require_backward_grad_sync = False - warn_if_output_is_none(output, "training_step") - elif trainer and trainer.testing: output = self.module.test_step(*inputs, **kwargs) - elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) - elif trainer and trainer.predicting: output = self.module.predict(*inputs, **kwargs) - warn_if_output_is_none(output, "predict") - else: output = self.module(*inputs, **kwargs) From 857d42a1a63a798514fd03203c79045574d3be1e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 23 Feb 2021 00:35:28 +0530 Subject: [PATCH 3/8] chlog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f8f7a08b089b..8bc2a20b6f39b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) +<<<<<<< HEAD - Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) @@ -27,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) +- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) @@ -49,6 +53,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) +- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + - Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166)) From 9028507c2819db6d61712178bc8a7b9afd0de97a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 23 Feb 2021 22:54:29 +0530 Subject: [PATCH 4/8] auto_opt --- tests/overrides/test_data_parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index e8aaefc2f973c..610d53d3af140 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -40,7 +40,6 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage): setattr(pl_module.trainer, p, p == prop) wrapped_module(batch, batch_idx) - getattr(pl_module, step).assert_called_with(batch, batch_idx) From 35e22065f23a8ca536c65d584b80aabe1941d0a9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 23 Feb 2021 23:32:25 +0530 Subject: [PATCH 5/8] no_warning_call --- tests/helpers/utils.py | 2 +- tests/overrides/test_data_parallel.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 5a7062829d738..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -124,7 +124,7 @@ def no_warning_call(warning_type, match: Optional[str] = None): try: w = record.pop(warning_type) - if not ((match and match in w.text) or w): + if not (match and match in str(w.message)): return except AssertionError: # no warning raised diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 610d53d3af140..c323e3676cf09 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -13,7 +13,11 @@ ) from pytorch_lightning.trainer.states import RunningStage from tests.helpers import BoringModel +<<<<<<< HEAD from tests.helpers.runif import RunIf +======= +from tests.helpers.utils import no_warning_call +>>>>>>> no_warning_call @pytest.mark.parametrize("wrapper_class", [ @@ -36,6 +40,7 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage): prop, step = stage pl_module.trainer.sanity_checking = False + for p in ("training", "testing", "validating", "predicting"): setattr(pl_module.trainer, p, p == prop) From 241fbc1438e690274fe9e6453af788a64917b483 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 26 Feb 2021 23:38:36 +0530 Subject: [PATCH 6/8] rm old code --- pytorch_lightning/overrides/base.py | 11 --------- tests/overrides/test_data_parallel.py | 5 ---- .../test_train_loop_flow_scalar_1_0.py | 24 ++++++++++++------- 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 9bdd1fcb1dc6e..46740e78dd2df 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - import torch from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): @@ -68,12 +63,6 @@ def on_post_move_to_device(self): pass -def warn_if_output_is_none(output: Any, method_name: str) -> None: - """ Warns user about which method returned None. """ - if output is None: - warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?') - - def unwrap_lightning_module(wrapped_model) -> LightningModule: model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index c323e3676cf09..4a6778da30654 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -5,7 +5,6 @@ from torch.nn import DataParallel from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.overrides.base import warning_cache from pytorch_lightning.overrides.data_parallel import ( LightningParallelModule, python_scalar_to_tensor, @@ -13,11 +12,7 @@ ) from pytorch_lightning.trainer.states import RunningStage from tests.helpers import BoringModel -<<<<<<< HEAD from tests.helpers.runif import RunIf -======= -from tests.helpers.utils import no_warning_call ->>>>>>> no_warning_call @pytest.mark.parametrize("wrapper_class", [ diff --git a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py index 0eec3c18cda83..d5a4da79942ed 100644 --- a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py +++ b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py @@ -24,6 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from tests.helpers.boring_model import BoringModel from tests.helpers.deterministic_model import DeterministicModel +from tests.helpers.utils import no_warning_call @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -211,7 +212,8 @@ def backward(self, loss, optimizer, optimizer_idx): def test_train_step_no_return(tmpdir): """ - Tests that only training_step can be used + Tests that only training_step raises a warning when + nothing is returned in case of automatic_optimization """ class TestModel(BoringModel): @@ -231,20 +233,26 @@ def validation_epoch_end(self, outputs): assert len(outputs) == 0 model = TestModel() - trainer = Trainer( + trainer_args = dict( default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - log_every_n_steps=1, - weights_summary=None, + fast_dev_run=2, ) - with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): + trainer = Trainer(**trainer_args) + + with pytest.warns(UserWarning, match=r'training_step returned None .*'): trainer.fit(model) + assert model.training_step_called assert model.validation_step_called + model = TestModel() + model.automatic_optimization = False + trainer = Trainer(**trainer_args) + + with no_warning_call(UserWarning, match=r'training_step returned None .*'): + trainer.fit(model) + def test_training_step_no_return_when_even(tmpdir): """ From 5f08204760873944400a59f6088422687512d46d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 26 Feb 2021 23:51:13 +0530 Subject: [PATCH 7/8] add warning for predict --- pytorch_lightning/trainer/predict_loop.py | 6 ++++++ tests/trainer/test_trainer.py | 21 +++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 40507a1bc03f4..4fe6960055ca9 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -14,6 +14,7 @@ import torch from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.warnings import WarningCache class PredictLoop(object): @@ -22,6 +23,7 @@ def __init__(self, trainer): self.trainer = trainer self.max_batches = None self.num_dataloaders = None + self.warning_cache = WarningCache() def on_trainer_init(self): self.trainer.num_predict_batches = [] @@ -74,6 +76,10 @@ def predict(self, batch, batch_idx, dataloader_idx): model_ref._current_fx_name = "predict" predictions = self.trainer.accelerator.predict(args) + + if predictions is None: + self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") + self._predictions[dataloader_idx].append(predictions) self.trainer._progress_bar_callback.on_predict_batch_end( self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3e090fb44943e..e359d2e0623dc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1393,11 +1393,11 @@ def predict_dataloader(self): return self._dataloaders -def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True): +def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - model = BoringModel() + model = model or BoringModel() datamodule = TestLightningDataModule(dataloaders) trainer = Trainer( @@ -1422,6 +1422,23 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T assert results[0][0].shape == torch.Size([1, 2]) +def test_trainer_predict_no_return(tmpdir): + """ + Test trainer.predict warns when nothing is returned + """ + + class CustomBoringModel(BoringModel): + + def predict(self, batch, batch_idx, dataloader_idx=None): + if (batch_idx + 1) % 2 == 0: + return + + return super().predict(batch, batch_idx, dataloader_idx) + + with pytest.warns(UserWarning, match='predict returned None'): + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) From d5584e1d20aa2bd1c1f83358564f8fbb880b9fee Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 6 Mar 2021 22:15:44 +0530 Subject: [PATCH 8/8] rebase --- CHANGELOG.md | 1 - pytorch_lightning/overrides/base.py | 2 +- .../plugins/training_type/ddp_spawn.py | 3 +-- .../plugins/training_type/tpu_spawn.py | 3 +-- .../logger_connector/logger_connector.py | 4 +--- pytorch_lightning/trainer/evaluation_loop.py | 3 +-- pytorch_lightning/trainer/trainer.py | 3 +-- tests/overrides/test_data_parallel.py | 14 ++++++++------ tests/trainer/test_states.py | 2 ++ tests/utilities/test_parsing.py | 12 ++++++++---- 10 files changed, 24 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bc2a20b6f39b..41ec984da3c88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) -<<<<<<< HEAD - Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 46740e78dd2df..1d6f4e93b5779 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -46,7 +46,7 @@ def forward(self, *inputs, **kwargs): # it is done manually in ``LightningModule.manual_backward`` # `require_backward_grad_sync` will be reset in the # ddp_plugin ``post_training_step`` hook - if self.module.automatic_optimization: + if not self.module.automatic_optimization: trainer.model.require_backward_grad_sync = False elif trainer and trainer.testing: output = self.module.test_step(*inputs, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3dace06cbf825..9f90ca2cf825b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -214,8 +214,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None if ( - self.lightning_module.trainer.state == TrainerState.FITTING - and best_model_path is not None + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index efada181ca9a6..0bea6ed56c5ab 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -139,8 +139,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None if ( - self.lightning_module.trainer.state == TrainerState.FITTING - and best_model_path is not None + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 2c6a0d613e648..15428c5d5c248 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -297,9 +297,7 @@ def get_evaluate_epoch_results(self): # log results of evaluation if ( - self.trainer.state != TrainerState.FITTING - and self.trainer.evaluating - and self.trainer.is_global_zero + self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate ): print('-' * 80) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d5047ce57858a..91cfc2ec757d5 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -60,8 +60,7 @@ def get_evaluation_dataloaders(self): self.trainer.reset_val_dataloader(model) if self.trainer.sanity_checking: self.trainer.num_sanity_val_batches = [ - min(self.trainer.num_sanity_val_steps, val_batches) - for val_batches in self.trainer.num_val_batches + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches ] max_batches = self.trainer.num_sanity_val_batches else: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc1964f07039b..dd0cd8c627965 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -875,8 +875,7 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) results = ( - self.__evaluate_given_model(model, dataloaders=test_dataloaders) - if model_provided else + self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 4a6778da30654..3921e7ef33b8e 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -19,12 +19,14 @@ LightningParallelModule, LightningDistributedModule, ]) -@pytest.mark.parametrize("stage", [ - ("training", "training_step"), - ("testing", "test_step"), - ("validating", "validation_step"), - ("predicting", "predict"), -]) +@pytest.mark.parametrize( + "stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), + ("predicting", "predict"), + ] +) def test_lightning_wrapper_module_methods(wrapper_class, stage): """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """ pl_module = MagicMock() diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index bedaef6d1ffb8..d2257a84f74db 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -34,6 +34,7 @@ def test_trainer_state_while_running(tmpdir, extra_params): trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True) class TestModel(BoringModel): + def __init__(self, expected_state): super().__init__() self.expected_state = expected_state @@ -78,6 +79,7 @@ def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): model = BoringModel() class InterruptCallback(Callback): + def on_batch_start(self, trainer, pl_module): raise KeyboardInterrupt diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index 391b0d9c97f0d..6ea10adf3d696 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -233,7 +233,9 @@ class UnpicklableClass: def test_parse_class_init_keys(tmpdir): + class Class: + def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): pass @@ -241,7 +243,9 @@ def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): def test_get_init_args(tmpdir): + class AutomaticArgsModel: + def __init__(self, anyarg, anykw=42, **kwargs): super().__init__() @@ -259,7 +263,9 @@ def get_init_args_wrapper(self): def test_collect_init_args(): + class AutomaticArgsParent: + def __init__(self, anyarg, anykw=42, **kwargs): super().__init__() self.get_init_args_wrapper() @@ -269,6 +275,7 @@ def get_init_args_wrapper(self): self.result = collect_init_args(frame, []) class AutomaticArgsChild(AutomaticArgsParent): + def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs): super().__init__(anyarg, anykw=anykw, **kwargs) @@ -299,10 +306,7 @@ def test_attribute_dict(tmpdir): def test_flatten_dict(tmpdir): - d = { - '1': 1, - '_': {'2': 2, '_': {'3': 3, '4': 4}} - } + d = {'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}} expected = { '1': 1,