Skip to content

Commit 55a90af

Browse files
authored
pytorch_lightning.loops file structure: group by dataloader, epoch, and batch loop (#8077)
1 parent 2c43bfc commit 55a90af

23 files changed

+117
-87
lines changed

CHANGELOG.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
146146
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
147147
* Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
148148
* Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
149-
* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))
149+
* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
150150
* Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
151-
* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
151+
* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
152152
* Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
153153
* Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024))
154154
* Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
155-
* Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))
155+
* Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
156156
* Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
157157

158158

pytorch_lightning/callbacks/finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def _store(
285285

286286
def on_train_epoch_start(self, trainer, pl_module):
287287
"""Called when the epoch begins."""
288-
for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
288+
for opt_idx, optimizer in trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers():
289289
num_param_groups = len(optimizer.param_groups)
290290
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
291291
current_param_groups = optimizer.param_groups

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ def training_step(...):
13711371

13721372
# backward
13731373
self._running_manual_backward = True
1374-
self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
1374+
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
13751375
self._running_manual_backward = False
13761376

13771377
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
@@ -1471,7 +1471,7 @@ def optimizer_step(
14711471
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
14721472
to ``optimizer.step()`` function as shown in the examples. This ensures that
14731473
``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within
1474-
:meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`.
1474+
:meth:`~pytorch_lightning.loops.training_batch_loop.TrainingBatchLoop.advance`.
14751475
14761476
Args:
14771477
epoch: Current epoch

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True):
120120
during the accumulation phase.
121121
Setting `sync_grad` to False will block this synchronization and improve performance.
122122
"""
123-
with self._trainer.fit_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
123+
with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
124124
self._toggle_model()
125125
yield
126126
self._untoggle_model()

pytorch_lightning/loops/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# limitations under the License.
1414

1515
from pytorch_lightning.loops.base import Loop # noqa: F401
16-
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401
17-
from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401
16+
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
17+
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
18+
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
1819
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
19-
from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401
20-
from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401

pytorch_lightning/loops/dataloader/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
# limitations under the License.
1414

1515
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401
16-
from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401
16+
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401
17+
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401

pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py renamed to pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
import pytorch_lightning as pl
2121
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
22-
from pytorch_lightning.loops.evaluation_epoch_loop import EvaluationEpochLoop
22+
from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop
2323
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
2424
from pytorch_lightning.trainer.states import TrainerFn
2525
from pytorch_lightning.utilities.model_helpers import is_overridden
2626
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
2727

2828

29-
class EvaluationDataLoaderLoop(DataLoaderLoop):
29+
class EvaluationLoop(DataLoaderLoop):
3030
"""Loops over all dataloaders for evaluation."""
3131

3232
def __init__(self):

pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py renamed to pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55

66
import pytorch_lightning as pl
77
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
8-
from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop
8+
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
99
from pytorch_lightning.plugins import DDPSpawnPlugin
1010
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1111
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
1212

1313

14-
class PredictionDataLoaderLoop(DataLoaderLoop):
14+
class PredictionLoop(DataLoaderLoop):
1515
"""Loop to run over dataloaders for prediction"""
1616

1717
def __init__(self):

0 commit comments

Comments
 (0)