
Commit ba8e7cd

awaelchli, pre-commit-ci[bot], justusschock, rohitgr7, and kaushikb11 authored

Fix BF16 teardown for TPU precision plugin (#10990)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: thomas chaton <[email protected]>

1 parent 235efb3 commit ba8e7cd
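
In short, this commit adds a `teardown` hook to `PrecisionPlugin`, implements it in `TPUBf16PrecisionPlugin` to pop the `XLA_USE_BF16` environment variable, and has `TrainingTypePlugin.teardown` delegate to the precision plugin. The ad-hoc deletion of `XLA_USE_BF16` inside `TPUSpawnPlugin.spawn` is removed, and the single-device, single-TPU, and TPU-spawn strategies now call `super().teardown()` so the new hook always runs.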

File tree: 10 files changed (+54 lines, -14 lines)


CHANGELOG.md

Lines changed: 15 additions & 7 deletions
```diff
@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))


+- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))
+
+
+
 ### Changed

 - Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
 * Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
 * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
-* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
-* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
-* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
-* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
 * Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
-* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
-* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
+* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
 * Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
-* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
 * Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
+* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
+* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
+* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
+* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
+* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))


 - Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))


+- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
+
+

 ## [1.5.7] - 2021-12-21
```

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -236,3 +236,9 @@ def predict_step_context(self) -> Generator[None, None, None]:
         """A contextmanager for the predict step."""
         with self.forward_context():
             yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """
```

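The new base-class hook gives precision plugins one place to undo whatever their setup changed. As a rough illustration only (this subclass is hypothetical and not part of the commit), a custom plugin could release its own state there:

```python
from pytorch_lightning.plugins import PrecisionPlugin


class ScratchPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that allocates scratch state it must later release."""

    def __init__(self) -> None:
        super().__init__()
        self._scratch = []

    def connect(self, model, optimizers, lr_schedulers):
        # Pretend some setup cost that should not outlive the run.
        self._scratch.append(bytearray(1024))
        return super().connect(model, optimizers, lr_schedulers)

    def teardown(self) -> None:
        # Invoked via TrainingTypePlugin.teardown() after this commit.
        self._scratch.clear()
```
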
pytorch_lightning/plugins/precision/tpu_bf16.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
     def connect(
         self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
     ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
         return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
```
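
For context, this is the plugin the Trainer picks when bf16 precision is requested on TPUs (as far as I can tell from the 1.5-era accelerator connector), so such runs no longer leave `XLA_USE_BF16` set in the parent process. A minimal sketch, assuming a TPU-capable environment:

```python
from pytorch_lightning import Trainer

# Requesting bf16 on TPU cores routes through TPUBf16PrecisionPlugin, which sets
# XLA_USE_BF16 on connect and, with this commit, pops it again on teardown.
trainer = Trainer(tpu_cores=8, precision="bf16", max_epochs=1)
```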

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -86,6 +86,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj

     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
```

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -74,6 +74,7 @@ def model_to_device(self) -> None:
         self.model.to(self.root_device)

     def teardown(self) -> None:
+        super().teardown()
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
```

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -244,9 +244,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
         }

     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         return xm.all_gather(tensor)

     def teardown(self) -> None:
+        super().teardown()
         os.environ.pop("PT_XLA_DEBUG", None)

     @classmethod
```

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -437,13 +437,13 @@ def model_sharded_context(self) -> Generator:
         """
         yield

-    @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
         """
         self._move_optimizer_state(torch.device("cpu"))
+        self.precision_plugin.teardown()

     @classmethod
     def register_plugins(cls, plugin_registry) -> None:
```
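
`teardown` is no longer abstract here and now also tears down the precision plugin, which is why the strategies above call `super().teardown()`. A self-contained schematic of the resulting call chain (simplified names, not the library classes):

```python
class _PrecisionLike:
    def teardown(self) -> None:
        print("precision cleanup, e.g. pop XLA_USE_BF16")


class _StrategyBase:
    def __init__(self, precision_plugin: "_PrecisionLike") -> None:
        self.precision_plugin = precision_plugin

    def teardown(self) -> None:
        # Base behaviour after this commit: delegate to the precision plugin.
        self.precision_plugin.teardown()


class _TPUSpawnLike(_StrategyBase):
    def teardown(self) -> None:
        super().teardown()  # runs the precision plugin teardown first
        print("strategy cleanup, e.g. pop PT_XLA_DEBUG")


_TPUSpawnLike(_PrecisionLike()).teardown()
# prints: precision cleanup, then strategy cleanup
```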

tests/models/test_tpu.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):

     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @RunIf(tpu=True)
```
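
These two assertions are dropped because `XLA_USE_BF16` is now removed during teardown, so it is no longer expected to be set once `run_model_test` returns; the set-and-clear behaviour is covered by the new precision plugin test below instead.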

tests/plugins/precision/__init__.py

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
```diff
@@ -0,0 +1,25 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUBf16PrecisionPlugin
+
+
+def test_teardown():
+    plugin = TPUBf16PrecisionPlugin()
+    plugin.connect(Mock(), Mock(), Mock())
+    assert os.environ.get("XLA_USE_BF16") == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
```
