
Commit 4decbc0

Deprecate dataloader_idx from on_train_batch_start/end (#9816)
* deprecate hooks
* dep todo
* explicit
* Apply suggestions from code review
* Apply suggestions from code review
* code review
* base
1 parent 0561fd6 commit 4decbc0

31 files changed (+150 −67 lines)
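The user-facing effect of this deprecation: `LightningModule` and `Callback` overrides of `on_train_batch_start` / `on_train_batch_end` should drop the trailing `dataloader_idx` parameter; overrides that still declare it explicitly keep receiving `0` until the argument is removed in v1.7. A minimal before/after sketch (the module name and empty bodies are illustrative only, not part of this commit):

import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # Before this change the hooks took a trailing dataloader_idx argument:
    #     def on_train_batch_start(self, batch, batch_idx, dataloader_idx): ...
    # After it, override without that argument:

    def on_train_batch_start(self, batch, batch_idx):
        pass

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass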

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 2 deletions

@@ -487,6 +487,7 @@ def on_train_end(self) -> None:
         """Called when train ends."""
         return self.training_type_plugin.on_train_end()

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
-        return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
+        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)

pytorch_lightning/callbacks/base.py

Lines changed: 7 additions & 2 deletions

@@ -97,7 +97,12 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
         pass

     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch begins."""
         pass

@@ -109,7 +114,7 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch ends."""
         pass
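A hedged sketch of a callback written against the new base signatures above (the callback name and timing logic are invented for illustration; the base class keeps a deprecated `unused` parameter with a default, so omitting it is fine):

import time

from pytorch_lightning import Callback


class BatchTimer(Callback):
    """Times each training batch using the post-deprecation hook signatures."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # dataloader_idx is no longer part of the signature
        self._start = time.monotonic()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f"batch {batch_idx} took {time.monotonic() - self._start:.3f}s")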

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions

@@ -135,7 +135,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

     @rank_zero_only
     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
     ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()

@@ -161,7 +161,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 4 deletions

@@ -279,7 +279,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
         if self._should_skip_saving_checkpoint(trainer):

@@ -304,9 +303,7 @@ def on_train_batch_end(

         self.save_checkpoint(trainer)

-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
         # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
         trainer.fit_loop.global_step -= 1
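The hook above drives step-interval checkpointing. A brief usage sketch, assuming the public `ModelCheckpoint` arguments `dirpath` and `every_n_train_steps` (the values are arbitrary):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint every 100 training steps; this path runs through the
# on_train_batch_end hook shown in the diff above.
checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=100)
trainer = Trainer(callbacks=[checkpoint_cb])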

pytorch_lightning/callbacks/progress/base.py

Lines changed: 3 additions & 3 deletions

@@ -35,8 +35,8 @@ def __init__(self):
     def disable(self):
         self.enable = False

-    def on_train_batch_end(self, trainer, pl_module, outputs):
-        super().on_train_batch_end(trainer, pl_module, outputs)  # don't forget this :)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch_idx)  # don't forget this :)
         percent = (self.train_batch_idx / self.total_train_batches) * 100
         sys.stdout.flush()
         sys.stdout.write(f'{percent:.01f} percent complete \r')

@@ -161,7 +161,7 @@ def on_train_start(self, trainer, pl_module):
     def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._train_batch_idx += 1

     def on_validation_start(self, trainer, pl_module):
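For completeness, a sketch of a custom progress bar built on the updated base-class signature (adapted from the docstring example in this file; it assumes `ProgressBarBase` is importable from `pytorch_lightning.callbacks` in this release line):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase


class PercentBar(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self.enable = True

    def disable(self):
        self.enable = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # dataloader_idx is no longer forwarded here
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        percent = 100 * self.train_batch_idx / self.total_train_batches
        print(f"\r{percent:.1f}% complete", end="")


trainer = Trainer(callbacks=[PercentBar()])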

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 2 additions & 2 deletions

@@ -369,8 +369,8 @@ def on_predict_epoch_start(self, trainer, pl_module):
         super().on_predict_epoch_start(trainer, pl_module)
         self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         self._update(self.main_progress_bar_id)

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 2 additions & 2 deletions

@@ -231,8 +231,8 @@ def on_train_epoch_start(self, trainer, pl_module):
         reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         total_batches = self.total_train_batches + self.total_val_batches
         total_batches = convert_inf(total_batches)
         if self._should_update(self.train_batch_idx, total_batches):

pytorch_lightning/core/hooks.py

Lines changed: 4 additions & 4 deletions

@@ -79,25 +79,25 @@ def on_pretrain_routine_end(self) -> None:
         - training_start
     """

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop before anything happens for that batch.

         If you return -1 here, you will skip training for the rest of the current epoch.

         Args:
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """

-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop after the batch.

         Args:
             outputs: The outputs of training_step_end(training_step(x))
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """

     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
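A minimal sketch of a module using the updated module-level hooks (the skip threshold is made up; returning -1 follows the docstring above):

import pytorch_lightning as pl


class DebugModule(pl.LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips training for the rest of the current epoch.
        if batch_idx >= 1000:
            return -1

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # `outputs` is training_step_end(training_step(x)) for this batch.
        pass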

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 9 additions & 1 deletion

@@ -24,6 +24,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache

 _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

@@ -76,7 +77,14 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)

         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_start
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             return AttributeDict(signal=-1)

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 10 additions & 1 deletion

@@ -27,6 +27,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -170,7 +171,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
                 automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
                 num_optimizers=len(self.trainer.optimizers),
             )
-            self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, 0)
+
+            # TODO: Update this in v1.7 (deprecation: #9816)
+            model_fx = self.trainer.lightning_module.on_train_batch_end
+            extra_kwargs = (
+                {"dataloader_idx": 0}
+                if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+                else {}
+            )
+            self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
             self.trainer.call_hook("on_batch_end")
             self.trainer.logger_connector.on_batch_end()