
Commit 176df20

Mark evaluation epoch loops attributes as protected (#8420)
* Mark evaluation epoch loops attributes as protected
* Fix pre-commit
1 parent 7d1f4ce commit 176df20
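
For context: a single leading underscore is the Python convention for attributes that are internal to a class and not part of its public API. A minimal sketch of the idea, using simplified hypothetical names rather than the actual Lightning loop:

class EvaluationEpochLoopSketch:
    # Illustrative only: underscore-prefixed attributes are internal bookkeeping,
    # while plain attributes remain part of the public surface.
    def __init__(self) -> None:
        self.outputs: list = []       # public: callers may read collected outputs
        self._dl_max_batches = None   # protected: loop-internal state
        self._num_dataloaders = None  # protected: loop-internal state

    def reset(self) -> None:
        # Internal state is cleared without affecting the public interface.
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

Renaming `dl_max_batches`, `dataloader_idx`, and `num_dataloaders` to underscore-prefixed forms signals that user code should not depend on them.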

File tree (2 files changed: +16, −18 lines)

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
pytorch_lightning/trainer/connectors/accelerator_connector.py

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 12 additions & 16 deletions
@@ -37,9 +37,8 @@ def __init__(self) -> None:
         super().__init__()
         self.predictions: Optional[PredictionCollection] = None
         self.dataloader: Optional[Iterator] = None
-        self.dl_max_batches: Optional[int] = None
-        self.dataloader_idx: Optional[int] = None
-        self.num_dataloaders: Optional[int] = None
+        self._dl_max_batches: Optional[int] = None
+        self._num_dataloaders: Optional[int] = None
         self.outputs: List[STEP_OUTPUT] = []
         self.progress = EpochProgress()

@@ -54,15 +53,14 @@ def connect(
     @property
     def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
-        return self.iteration_count >= self.dl_max_batches
+        return self.iteration_count >= self._dl_max_batches

     def reset(self) -> None:
         """Resets the loop's internal state."""
         self.iteration_count = 0
         self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
-        self.dl_max_batches = None
-        self.dataloader_idx = None
-        self.num_dataloaders = None
+        self._dl_max_batches = None
+        self._num_dataloaders = None
         self.outputs = []

     def on_run_start(

@@ -80,11 +78,9 @@ def on_run_start(
             dl_max_batches: maximum number of batches the dataloader can produce
             num_dataloaders: the total number of dataloaders
         """
-        void(dataloader_iter)
-
-        self.dl_max_batches = dl_max_batches
-        self.dataloader_idx = dataloader_idx
-        self.num_dataloaders = num_dataloaders
+        void(dataloader_iter, dataloader_idx)
+        self._dl_max_batches = dl_max_batches
+        self._num_dataloaders = num_dataloaders

     def advance(
         self,

@@ -182,8 +178,8 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         """
         self.trainer.logger_connector.on_batch_start()

-        assert self.num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders)
+        assert self._num_dataloaders is not None
+        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)

         if self.trainer.testing:
             self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)

@@ -243,8 +239,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
         # make dataloader_idx arg in validation_step optional
         step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

-        multiple_val_loaders = not self.trainer.testing and self.num_dataloaders > 1
-        multiple_test_loaders = self.trainer.testing and self.num_dataloaders > 1
+        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
+        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1

         if multiple_test_loaders or multiple_val_loaders:
             step_kwargs["dataloader_idx"] = dataloader_idx
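
One detail in the `on_run_start` hunk: `dataloader_idx` is no longer stored on the loop, it is only routed through `void(...)` so the argument is deliberately consumed. The actual helper lives elsewhere in the codebase; a sketch of what such a no-op typically looks like, assuming it simply discards whatever it receives:

from typing import Any

def void(*args: Any, **kwargs: Any) -> None:
    # Accept and ignore any arguments, making "intentionally unused" explicit
    # and keeping linters quiet about unused parameters.
    pass

Keeping the parameter in the hook signature while passing it to a no-op preserves the loop interface even though the value is no longer needed internally.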

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 2 deletions
@@ -459,8 +459,10 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                         "You have asked for native AMP on CPU, but AMP is only available on GPU."
                     )
                 if not _NATIVE_AMP_AVAILABLE:
-                    msg = "You have asked for native AMP but your PyTorch version does not support it." \
-                        " Consider upgrading with `pip install torch>=1.6`."
+                    msg = (
+                        "You have asked for native AMP but your PyTorch version does not support it."
+                        " Consider upgrading with `pip install torch>=1.6`."
+                    )
                     if _APEX_AVAILABLE:
                         self.amp_type = AMPType.APEX
                         msg += " We will attempt to use NVIDIA Apex for this session."

0 commit comments

Comments
 (0)