
Commit 2242423

refactor accelerator teardown -> training type plugin teardown (#7579)
1 parent a8d9b5f commit 2242423

15 files changed: +237 −32 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))
 
 
+- Changed `teardown()` in `Accelerator` to allow `training_type_plugin` to customize `teardown` logic ([#7579](https://github.com/PyTorchLightning/pytorch-lightning/pull/7579))
+
+
 ### Deprecated
 
 
pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 5 deletions
@@ -151,17 +151,15 @@ def lightning_module(self) -> 'pl.LightningModule':
 
     @property
     def root_device(self) -> torch.device:
+        """Returns the root device"""
        return self.training_type_plugin.root_device
 
     def teardown(self) -> None:
         """
         This method is called to teardown the training process.
-        It is the right place to release memory and free other ressources.
-
-        By default we add a barrier here to synchronize processes before returning
-        control back to the caller.
+        It is the right place to release memory and free other resources.
         """
-        self.barrier("teardown")
+        self.training_type_plugin.teardown()
 
     def batch_to_device(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
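
With this change, `Accelerator.teardown()` is a thin pass-through to the training type plugin, so cleanup behaviour is decided entirely by the plugin. A minimal sketch of the new call path, assuming a Lightning version that includes this commit (the Trainer normally wires these objects up itself):

import torch

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# wire an accelerator to a CPU single-device plugin by hand
plugin = SingleDevicePlugin(torch.device("cpu"))
accelerator = Accelerator(precision_plugin=PrecisionPlugin(), training_type_plugin=plugin)

# after this commit, teardown is forwarded to the plugin instead of calling
# self.barrier("teardown") on the accelerator itself
accelerator.teardown()  # -> plugin.teardown(), a no-op on CPU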

pytorch_lightning/accelerators/gpu.py

Lines changed: 0 additions & 7 deletions
@@ -47,13 +47,6 @@ def on_train_start(self) -> None:
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
 
-    def teardown(self) -> None:
-        self.lightning_module.cpu()
-
-        # clean up memory
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
-
     @staticmethod
     def set_nvidia_flags(local_rank: int) -> None:
         # set the correct cuda visible devices (using pci order)
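
The CUDA cleanup removed from `GPUAccelerator` is not lost; the same move-to-CPU-and-empty-cache logic reappears in the plugin hunks below. For reference, the idiom itself is plain PyTorch (sketch; only meaningful on a CUDA machine):

import torch

model = torch.nn.Linear(8, 8)
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    model.to(device)
    # what the removed GPUAccelerator.teardown performed, now owned by the plugin:
    model.cpu()                         # move parameters off the GPU
    with torch.cuda.device(device):
        torch.cuda.empty_cache()        # release cached blocks back to the driver
    print(torch.cuda.memory_reserved(device))  # close to zero once nothing is cached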

pytorch_lightning/accelerators/tpu.py

Lines changed: 0 additions & 5 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any, Callable
 
 from torch.optim import Optimizer
@@ -51,10 +50,6 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def teardown(self) -> None:
-        if "PT_XLA_DEBUG" in os.environ:
-            del os.environ["PT_XLA_DEBUG"]
-
     def run_optimizer_step(
         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
     ) -> None:
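
The `PT_XLA_DEBUG` cleanup moves into the TPU plugins (see `single_tpu.py` and `tpu_spawn.py` below) and is rewritten with `os.environ.pop`, which does the same job as the removed check-then-delete but never raises. A small illustration:

import os

# removed pattern (TPUAccelerator.teardown):
if "PT_XLA_DEBUG" in os.environ:
    del os.environ["PT_XLA_DEBUG"]

# new pattern used by the plugins: safe whether or not the variable is set
os.environ.pop("PT_XLA_DEBUG", None)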

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ def __init__(
         self.set_world_ranks()
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
 
     @property
@@ -126,7 +126,7 @@ def distributed_sampler_kwargs(self):
     def _is_single_process_single_device(self) -> bool:
         return True
 
-    def setup_environment(self):
+    def setup_environment(self) -> None:
         # start the other scripts
         if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 15 additions & 2 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
 
 
@@ -40,13 +41,17 @@ def __init__(
 
     @property
     @abstractmethod
-    def root_device(self):
+    def root_device(self) -> torch.device:
         raise NotImplementedError
 
     @property
-    def on_gpu(self):
+    def on_gpu(self) -> bool:
         return self.root_device.type == "cuda" and torch.cuda.is_available()
 
+    @property
+    def on_tpu(self) -> bool:
+        return self.root_device.type == "xla" and _XLA_AVAILABLE
+
     @property
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
@@ -122,3 +127,11 @@ def block_backward_sync(self):
             yield None
         else:
             yield None
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            with torch.cuda.device(self.root_device):
+                torch.cuda.empty_cache()
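
Because `ParallelPlugin` now carries the GPU-aware `teardown`, every subclass (DDP and friends) inherits it, and custom plugins can extend it rather than touching the accelerator. A hypothetical sketch (class definition only; a Trainer would normally construct and configure the plugin):

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class CustomTeardownDDPPlugin(DDPPlugin):
    """Hypothetical plugin adding extra cleanup on top of the inherited teardown."""

    def teardown(self) -> None:
        # ParallelPlugin.teardown moves the module to CPU and empties the CUDA
        # cache when self.on_gpu is True
        super().teardown()
        # extra, plugin-specific resource release would go here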

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 11 additions & 2 deletions
@@ -16,6 +16,7 @@
 import torch
 
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.utilities import _XLA_AVAILABLE
 
 
 class SingleDevicePlugin(TrainingTypePlugin):
@@ -30,11 +31,11 @@ def __init__(self, device: torch.device):
 
     @property
     def on_tpu(self) -> bool:
-        return False
+        return self.root_device.type == "xla" and _XLA_AVAILABLE
 
     @property
     def on_gpu(self) -> bool:
-        return self.device.type == "cuda" and torch.cuda.is_available()
+        return self.root_device.type == "cuda" and torch.cuda.is_available()
 
     def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
         """
@@ -78,3 +79,11 @@ def barrier(self, *args, **kwargs) -> None:
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            with torch.cuda.device(self.root_device):
+                torch.cuda.empty_cache()
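
`on_tpu` and `on_gpu` are now derived from `root_device`, so a plugin built for one device type answers consistently. A quick check on CPU (sketch, assuming the constructor shown in this hunk):

import torch

from pytorch_lightning.plugins import SingleDevicePlugin

plugin = SingleDevicePlugin(torch.device("cpu"))
print(plugin.root_device)  # cpu
print(plugin.on_gpu)       # False: root_device.type is not "cuda"
print(plugin.on_tpu)       # False: root_device.type is not "xla"
plugin.teardown()          # no-op on CPU; on GPU it would move the model back and empty the cache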

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 4 additions & 4 deletions
@@ -35,10 +35,6 @@ def __init__(self, device: int, debug: bool = False):
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
 
-    @property
-    def on_tpu(self) -> bool:
-        return True
-
     @property
     def is_distributed(self) -> bool:
         return False
@@ -63,3 +59,7 @@ def on_save(self, checkpoint: dict) -> dict:
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
         return move_data_to_device(checkpoint, torch.device("cpu"))
+
+    def teardown(self) -> None:
+        # TPU teardown
+        os.environ.pop("PT_XLA_DEBUG", None)

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 9 additions & 5 deletions
@@ -71,7 +71,7 @@ def world_size(self) -> int:
 
     @property
     def root_device(self) -> torch.device:
-        return self.device
+        return xm.xla_device()
 
     @staticmethod
     def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
@@ -129,7 +129,7 @@ def is_distributed(self) -> bool:
 
     def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
         TPUSpawnPlugin._validate_dataloader(dataloader)
-        return MpDeviceLoader(dataloader, self.device)
+        return MpDeviceLoader(dataloader, self.root_device)
 
     def configure_ddp(self) -> None:
         pass
@@ -172,8 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         time.sleep(2)
 
     def model_to_device(self) -> None:
-        self.device = xm.xla_device()
-        self.model = self.wrapped_model.to(self.device)
+        self.model = self.wrapped_model.to(self.root_device)
 
     def barrier(self, name: Optional[str] = None) -> None:
         # HOST_WORLD_SIZE is None outside the xmp.spawn process
@@ -209,7 +208,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
         torch.save(obj, buffer)
         data = bytearray(buffer.getbuffer())
-        data_tensor = torch.tensor(data, device=self.device, dtype=torch.float)
+        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
         data = xm.all_gather(data_tensor)
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
@@ -302,3 +301,8 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
         return xm.all_gather(tensor)
+
+    def teardown(self) -> None:
+        # TPU teardown
+        os.environ.pop("PT_XLA_DEBUG", None)
+        self.barrier("teardown")

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 16 additions & 0 deletions
@@ -60,11 +60,19 @@ def setup(self, model: Module) -> None:
     @abstractmethod
     def on_gpu(self) -> bool:
         """Returns whether the current process is done on GPU"""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def on_tpu(self) -> bool:
+        """Returns whether the current process is done on TPU"""
+        raise NotImplementedError
 
     @property
     @abstractmethod
     def root_device(self) -> torch.device:
         """Returns the root device"""
+        raise NotImplementedError
 
     @abstractmethod
     def model_to_device(self) -> None:
@@ -290,6 +298,14 @@ def call_configure_sharded_model_hook(self) -> bool:
     def call_configure_sharded_model_hook(self, mode: bool) -> None:
         self._call_configure_sharded_model_hook = mode
 
+    @abstractmethod
+    def teardown(self) -> None:
+        """
+        This method is called to teardown the training process.
+        It is the right place to release memory and free other resources.
+        """
+        raise NotImplementedError
+
     @classmethod
     def register_plugins(cls, plugin_registry):
         pass
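
With `teardown` (and `on_tpu`) added to the abstract interface, plugins that subclass `TrainingTypePlugin` directly must now implement them. A small check of the contract (sketch; the `getattr` fallback guards against versions where the class is not registered as an ABC):

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

abstract = getattr(TrainingTypePlugin, "__abstractmethods__", frozenset())
print(sorted(abstract))  # expected to include "teardown", "on_tpu", "on_gpu", "root_device", ...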
