@@ -37,9 +37,8 @@ def __init__(self) -> None:
3737 super ().__init__ ()
3838 self .predictions : Optional [PredictionCollection ] = None
3939 self .dataloader : Optional [Iterator ] = None
40- self .dl_max_batches : Optional [int ] = None
41- self .dataloader_idx : Optional [int ] = None
42- self .num_dataloaders : Optional [int ] = None
40+ self ._dl_max_batches : Optional [int ] = None
41+ self ._num_dataloaders : Optional [int ] = None
4342 self .outputs : List [STEP_OUTPUT ] = []
4443 self .progress = EpochProgress ()
4544
@@ -54,15 +53,14 @@ def connect(
5453 @property
5554 def done (self ) -> bool :
5655 """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
57- return self .iteration_count >= self .dl_max_batches
56+ return self .iteration_count >= self ._dl_max_batches
5857
5958 def reset (self ) -> None :
6059 """Resets the loop's internal state."""
6160 self .iteration_count = 0
6261 self .predictions = PredictionCollection (self .trainer .global_rank , self .trainer .world_size )
63- self .dl_max_batches = None
64- self .dataloader_idx = None
65- self .num_dataloaders = None
62+ self ._dl_max_batches = None
63+ self ._num_dataloaders = None
6664 self .outputs = []
6765
6866 def on_run_start (
@@ -80,11 +78,9 @@ def on_run_start(
8078 dl_max_batches: maximum number of batches the dataloader can produce
8179 num_dataloaders: the total number of dataloaders
8280 """
83- void (dataloader_iter )
84-
85- self .dl_max_batches = dl_max_batches
86- self .dataloader_idx = dataloader_idx
87- self .num_dataloaders = num_dataloaders
81+ void (dataloader_iter , dataloader_idx )
82+ self ._dl_max_batches = dl_max_batches
83+ self ._num_dataloaders = num_dataloaders
8884
8985 def advance (
9086 self ,
@@ -182,8 +178,8 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
182178 """
183179 self .trainer .logger_connector .on_batch_start ()
184180
185- assert self .num_dataloaders is not None
186- self .trainer .logger_connector .on_evaluation_batch_start (batch , batch_idx , dataloader_idx , self .num_dataloaders )
181+ assert self ._num_dataloaders is not None
182+ self .trainer .logger_connector .on_evaluation_batch_start (batch , batch_idx , dataloader_idx , self ._num_dataloaders )
187183
188184 if self .trainer .testing :
189185 self .trainer .call_hook ("on_test_batch_start" , batch , batch_idx , dataloader_idx )
@@ -243,8 +239,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
243239 # make dataloader_idx arg in validation_step optional
244240 step_kwargs = OrderedDict ([("batch" , batch ), ("batch_idx" , batch_idx )])
245241
246- multiple_val_loaders = not self .trainer .testing and self .num_dataloaders > 1
247- multiple_test_loaders = self .trainer .testing and self .num_dataloaders > 1
242+ multiple_val_loaders = not self .trainer .testing and self ._num_dataloaders > 1
243+ multiple_test_loaders = self .trainer .testing and self ._num_dataloaders > 1
248244
249245 if multiple_test_loaders or multiple_val_loaders :
250246 step_kwargs ["dataloader_idx" ] = dataloader_idx
0 commit comments