Skip to content

Commit 40a8ebf

Browse files
committed
Fixes for LightningCLI
1 parent a4fb139 commit 40a8ebf

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,7 @@ def update_train_step_metrics(self) -> None:
228228
return
229229

230230
# TODO: remove this call in v1.7
231-
self._log_gpus_metrics()
231+
# self._log_gpus_metrics()
232232

233233
# when metrics should be logged
234234
assert not self._epoch_end_reached

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1415,7 +1415,7 @@ def call_hook(
14151415
# Rely on the accelerator output if lightningModule hook returns nothing
14161416
# Required for cases such as DataParallel where we reduce the output for the user
14171417
# todo: move this data parallel logic into the data parallel plugin
1418-
output = accelerator_output if output is None else output
1418+
output = ttp_output if output is None else output
14191419

14201420
# call the ttp hook
14211421
if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(

tests/helpers/boring_model.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,18 +158,15 @@ def prepare_data(self):
158158
def setup(self, stage: Optional[str] = None):
159159
if stage == "fit" or stage is None:
160160
self.random_train = Subset(self.random_full, indices=range(64))
161-
self.dims = self.random_train[0].shape
162161

163162
if stage in ("fit", "validate") or stage is None:
164163
self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
165164

166165
if stage == "test" or stage is None:
167166
self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
168-
self.dims = getattr(self, "dims", self.random_test[0].shape)
169167

170168
if stage == "predict" or stage is None:
171169
self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
172-
self.dims = getattr(self, "dims", self.random_predict[0].shape)
173170

174171
def train_dataloader(self):
175172
return DataLoader(self.random_train)

tests/utilities/test_cli.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -574,7 +574,7 @@ def add_arguments_to_parser(self, parser):
574574

575575
class EarlyExitTestModel(BoringModel):
576576
def on_fit_start(self):
577-
raise Exception("Error on fit start")
577+
raise MisconfigurationException("Error on fit start")
578578

579579

580580
@pytest.mark.parametrize("logger", (False, True))
@@ -586,8 +586,10 @@ def on_fit_start(self):
586586
pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
587587
),
588588
)
589-
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
590-
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(Exception, match=r"Error on fit start"):
589+
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
590+
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
591+
MisconfigurationException, match=r"Error on fit start"
592+
):
591593
LightningCLI(
592594
EarlyExitTestModel,
593595
trainer_defaults={

0 commit comments

Comments (0)