
Commit f69bb3a

awaelchli, pre-commit-ci[bot], justusschock, rohitgr7, kaushikb11 committed

Fix BF16 teardown for TPU precision plugin (#10990)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: thomas chaton <[email protected]>

1 parent 1fda63c, commit f69bb3a

File tree

11 files changed (+41 additions, -5 deletions)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
 - Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))
+- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
+

 ## [1.5.7] - 2021-12-21


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 0 deletions

@@ -188,6 +188,7 @@ def teardown(self) -> None:
         It is the right place to release memory and free other resources.
         """
         self.training_type_plugin.teardown()
+        self.precision_plugin.teardown()

     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
         """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 6 additions & 0 deletions

@@ -252,3 +252,9 @@ def predict_step_context(self) -> Generator[None, None, None]:
         """A contextmanager for the predict step."""
         with self.forward_context():
             yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """

pytorch_lightning/plugins/precision/tpu_bf16.py

Lines changed: 4 additions & 1 deletion

@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
     def connect(
         self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
     ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
         return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj

     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

     def teardown(self) -> None:
+        super().teardown()
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 1 deletion

@@ -357,7 +357,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         return xm.all_gather(tensor)

     def teardown(self) -> None:
-        # TPU teardown
+        super().teardown()
         os.environ.pop("PT_XLA_DEBUG", None)
         self.barrier("teardown")

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 1 deletion

@@ -312,7 +312,6 @@ def model_sharded_context(self) -> Generator:
         """
         yield

-    @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.
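
Dropping @abstractmethod makes the base teardown an ordinary method, which is what allows the super().teardown() calls added above in SingleDevicePlugin, SingleTPUPlugin and TPUSpawnPlugin: each subclass runs the inherited cleanup and then its own. A minimal sketch of that cooperative chain, with hypothetical classes standing in for the real plugin hierarchy:

class BaseTrainingTypeSketch:
    def teardown(self) -> None:
        # shared cleanup every subclass should run
        print("base cleanup")


class TPUSpawnSketch(BaseTrainingTypeSketch):
    def teardown(self) -> None:
        super().teardown()         # inherited cleanup first
        print("pop PT_XLA_DEBUG")  # then plugin-specific cleanup


TPUSpawnSketch().teardown()
# prints:
#   base cleanup
#   pop PT_XLA_DEBUG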

tests/models/test_tpu.py

Lines changed: 0 additions & 2 deletions

@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):

     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @pytest.mark.parametrize("tpu_core", [1, 5])

@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @RunIf(tpu=True)
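
These assertions checked XLA_USE_BF16 after run_model_test returned, but with the new teardown the variable has already been popped by then, so the checks had to go. A hedged sketch of an equivalent standalone check (a hypothetical test, not one of the files in this commit; it assumes TPUBf16PrecisionPlugin needs no constructor arguments and exercises only the teardown shown in the diff above):

import os

from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin


def test_xla_use_bf16_removed_on_teardown():
    # Simulate what connect() does, then check that teardown() undoes it.
    os.environ["XLA_USE_BF16"] = "1"
    TPUBf16PrecisionPlugin().teardown()
    assert "XLA_USE_BF16" not in os.environ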

tests/plugins/precision/__init__.py

Whitespace-only changes.
