From 5cc995938544655fd26ba189fb9215240366a4c3 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 5 Jan 2021 16:14:03 +0530 Subject: [PATCH 01/27] added on_post_move_to_device --- pytorch_lightning/core/decorators.py | 21 +++++++++++++++++++ pytorch_lightning/core/hooks.py | 16 ++++++++++++++ .../utilities/device_dtype_mixin.py | 3 +++ 3 files changed, 40 insertions(+) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 938db9cc20b00..0a60c670dd01d 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -18,6 +18,7 @@ from typing import Callable from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn def auto_move_data(fn: Callable) -> Callable: @@ -64,3 +65,23 @@ def auto_transfer_args(self, *args, **kwargs): return fn(self, *args, **kwargs) return auto_transfer_args + + +def parameter_validation(fn: Callable) -> Callable: + @wraps(fn) + def inner_f(self, *args, **kwargs): + if not isinstance(self, LightningModule): + return fn(self, *args, **kwargs) + pre_param_count = len(list(self.parameters())) + module = fn(self, *args, **kwargs) + self.on_post_move_to_device() + post_param_count = len(list(self.parameters())) + + if not pre_param_count == post_param_count: + rank_zero_warn('The parameter count does not match after moving target device. ' + 'If your model employs weight sharing on TPU,' + 'please tie your weights in the `on_post_move_to_device` model hook.') + + return module + + return inner_f diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index f27c18513831f..83206104b41fe 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -314,6 +314,22 @@ def on_after_backward(self): """ + def on_post_move_to_device(self) -> None: + """ + Called in the parameter_validation decorator after Lightning.to is called + This is a good place to tie weights between modules after moving them to a device. + Can be used when training models with weight sharing properties on TPU. + + Addresses the handling of shared weights on TPU: + https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks + + Example:: + + def on_post_move_to_device(self): + self.decoder.weight.data = self.encoder.weight.data.transpose(0,1) + + """ + class DataHooks: """Hooks to be used with LightningDataModule.""" diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 79182cd086f6f..63de2a029ab53 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -17,6 +17,8 @@ import torch from torch.nn import Module +from pytorch_lightning.core.decorators import parameter_validation + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ['device', 'dtype'] @@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]): # Necessary to avoid infinite recursion raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).') + @parameter_validation def to(self, *args, **kwargs) -> Module: """Moves and/or casts the parameters and buffers. 
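For context, the hook introduced above is meant to be implemented by user code roughly as follows. This is a minimal sketch, not part of the patch itself: the module, layer names and sizes are illustrative, and it assumes a Lightning build that already contains the `on_post_move_to_device` hook::

    import pytorch_lightning as pl
    from torch import nn


    class TiedLM(pl.LightningModule):
        """Illustrative module: the output projection shares the embedding matrix."""

        def __init__(self, vocab_size: int = 100, hidden: int = 16):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden)
            self.decoder = nn.Linear(hidden, vocab_size, bias=False)
            # Tie at construction time; this holds on CPU/GPU, but the tie is
            # lost when XLA copies each tensor to the TPU device independently.
            self.decoder.weight = self.embedding.weight

        def forward(self, tokens):
            return self.decoder(self.embedding(tokens))

        def on_post_move_to_device(self):
            # Re-establish the tie after the module has been moved to the
            # device, which is exactly when this hook is invoked.
            self.decoder.weight = self.embedding.weight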
From 3d1a3138b44d1dd47c72e8a81b5cf2f4f99f077e Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 9 Jan 2021 17:54:11 +0530 Subject: [PATCH 02/27] added tests --- CHANGELOG.md | 1 + .../accelerators/tpu_accelerator.py | 21 +++--- pytorch_lightning/core/decorators.py | 8 +-- pytorch_lightning/core/hooks.py | 2 +- .../utilities/device_dtype_mixin.py | 3 + tests/backends/test_tpu_backend.py | 67 +++++++++++++++++++ 6 files changed, 84 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b56321765bbf0..6104667549215 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Support to tie weights after moving model to TPU via `on_post_move_to_device` hook - Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467)) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 66fc236a2a775..985486e50c631 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -194,18 +194,6 @@ def __save_end_of_training_weights(self, model: LightningModule, trainer): self.save_spawn_weights(model) def __setup_tpu_training(self, model: LightningModule, trainer): - # use the default device from the process - # tpu_device = xm.xla_device() - - # if given an ordinal device, use this as the device - if trainer.tpu_id is not None: - tpu_device = xm.xla_device(trainer.tpu_id) - else: - tpu_device = xm.xla_device() - # track the device and move model to it - trainer._device = tpu_device - model.to(trainer._device) - # get the appropriate tpu ranks trainer.tpu_local_core_rank = xm.get_local_ordinal() trainer.tpu_global_core_rank = xm.get_ordinal() @@ -217,6 +205,15 @@ def __setup_tpu_training(self, model: LightningModule, trainer): trainer.global_rank = trainer.tpu_local_core_rank rank_zero_only.rank = trainer.global_rank + # if given an ordinal device, use this as the device + if trainer.tpu_id is not None: + tpu_device = xm.xla_device(trainer.tpu_id) + else: + tpu_device = xm.xla_device() + # track the device and move model to it + trainer._device = tpu_device + model.to(trainer._device) + # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 0a60c670dd01d..af9d2e3bb7323 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -17,7 +17,6 @@ from functools import wraps from typing import Callable -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn @@ -57,6 +56,7 @@ def forward(self, x): """ @wraps(fn) def auto_transfer_args(self, *args, **kwargs): + from pytorch_lightning.core.lightning import LightningModule if not isinstance(self, LightningModule): return fn(self, *args, **kwargs) @@ -70,17 +70,15 @@ def auto_transfer_args(self, *args, **kwargs): def parameter_validation(fn: Callable) -> Callable: @wraps(fn) def inner_f(self, *args, **kwargs): - if not isinstance(self, LightningModule): - return fn(self, *args, **kwargs) pre_param_count = len(list(self.parameters())) module = fn(self, *args, **kwargs) self.on_post_move_to_device() post_param_count = len(list(self.parameters())) if not 
pre_param_count == post_param_count: - rank_zero_warn('The parameter count does not match after moving target device. ' + rank_zero_warn('The model parameters do not match after moving to the target device. ' 'If your model employs weight sharing on TPU,' - 'please tie your weights in the `on_post_move_to_device` model hook.') + 'please tie your weights using the `on_post_move_to_device` model hook.') return module diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 83206104b41fe..00f151924e704 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -326,7 +326,7 @@ def on_post_move_to_device(self) -> None: Example:: def on_post_move_to_device(self): - self.decoder.weight.data = self.encoder.weight.data.transpose(0,1) + self.decoder.weight = self.encoder.weight """ diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 63de2a029ab53..3627c0c493747 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -89,6 +89,9 @@ def to(self, *args, **kwargs) -> Module: ... def __init__(self, weight: torch.Tensor): ... super().__init__() ... self.register_buffer('weight', weight) + ... + ... def on_post_move_to_device(self): + ... pass >>> _ = torch.manual_seed(0) >>> module = ExampleModule(torch.rand(3, 4)) >>> module.weight #doctest: +ELLIPSIS diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index 17e67755fafd7..954deb9c4cdb5 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -14,9 +14,11 @@ import pytest import torch +from torch import nn from pytorch_lightning import Trainer from pytorch_lightning.trainer.states import TrainerState +from tests.base import SimpleModule from pytorch_lightning.utilities.xla_device import XLADeviceUtils from tests.base.boring_model import BoringModel from tests.base.develop_utils import pl_multi_process_test @@ -61,3 +63,68 @@ def test_if_test_works_after_train(tmpdir): trainer.fit(model) assert trainer.test() == 1 + + +@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +@pl_multi_process_test +def test_weight_tying_warning(tmpdir, capsys=None): + """ + Ensure a warning is thrown if model parameter lengths do not match + post moving to device. + """ + + class WeightSharingModule(SimpleModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + self.layer_3.weight = self.layer_1.weight + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x + + model = WeightSharingModule() + trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) + + with pytest.warns(UserWarning, match=r'The model parameters do not match after moving to the target device.'): + result = trainer.fit(model) + assert result + + +@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +@pl_multi_process_test +def test_if_weights_tied(tmpdir, capsys=None): + """ + Test if weights are properly tied on `on_post_move_to_device`. + Ensure no warning for parameter mismatch is thrown. 
+    """
+
+    class WeightSharingModule(SimpleModule):
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Linear(32, 10, bias=False)
+            self.layer_2 = nn.Linear(10, 32, bias=False)
+            self.layer_3 = nn.Linear(32, 10, bias=False)
+            self.layer_3.weight = self.layer_1.weight
+
+        def forward(self, x):
+            x = self.layer_1(x)
+            x = self.layer_2(x)
+            x = self.layer_3(x)
+            return x
+
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
+
+    model = WeightSharingModule()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+
+    with pytest.warns(UserWarning) as warnings:
+        result = trainer.fit(model)
+        assert result
+
+    assert not list(filter(lambda x: 'The model parameters do not match' in str(x), warnings.list))

From bc7910b24292350e301589ac63ecd2f705259848 Mon Sep 17 00:00:00 2001
From: Lezwon Castelino
Date: Sun, 10 Jan 2021 16:44:39 +0530
Subject: [PATCH 03/27] docs and refactors

---
 docs/source/tpu.rst                  | 54 +++++++++++++++++++++++++++-
 pytorch_lightning/core/decorators.py | 17 +++++++++
 tests/backends/test_tpu_backend.py   | 34 ++----------------
 tests/base/weight_sharing_module.py  | 18 ++++++++++
 4 files changed, 91 insertions(+), 32 deletions(-)
 create mode 100644 tests/base/weight_sharing_module.py

diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst
index 549a3a1cd25d6..3648c4630e838 100644
--- a/docs/source/tpu.rst
+++ b/docs/source/tpu.rst
@@ -191,7 +191,59 @@ set the 16-bit flag.
 Under the hood the xla library will use the `bfloat16 type
 `_.
 
------------------
+
+-----------------
+
+Weight Sharing/Tying
+-----------------------
+Weight Tying/Sharing is a technique in which module weights are shared among two or more layers.
+This is a common method to reduce memory consumption and is utilized in many state-of-the-art
+architectures today.
+
+PyTorch XLA requires these weights to be tied/shared after moving the model
+to the TPU device. To support this requirement, Lightning provides a model hook which is
+called after the model is moved to the device. Any weights that need to be tied should
+be tied in the `on_post_move_to_device` model hook. This will ensure that the weights
+among the modules are shared and not copied.
+
+PyTorch Lightning has a built-in check which verifies that the model parameter lengths
+match once the model is moved to the device. If the lengths do not match, Lightning
+throws a warning message.
+
+Example:
+
+..
code-block:: python + + import pytorch_lightning as pl + from torch import nn + + + class WeightSharingModule(pl.LightningModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + self.layer_3.weight = self.layer_1.weight # Weights will be copied on TPU + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x + + def on_post_move_to_device(self): + # Weights shared after the model has been moved to TPU Device + self.layer_3.weight = self.layer_1.weight + + + model = WeightSharingModule() + trainer = Trainer(max_epochs=1, tpu_cores=8) + result = trainer.fit(model) + +See `XLA Documentation `_ + +----------------------- Performance considerations -------------------------- diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index af9d2e3bb7323..d63b390eb1484 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -68,6 +68,23 @@ def auto_transfer_args(self, *args, **kwargs): def parameter_validation(fn: Callable) -> Callable: + """ + Decorator for `~pytorch_lightning.core.LightningModule.to` method. + Validates that the module parameter lengths match after moving to the device. It is useful + when tying weights on TPU's. + + Args: + fn: `.to` method + + Note: + TPU's require weights to be tied/shared after moving the module to the device. + Failure to do this results in the initialization of new weights which are not tied. + To overcome this issue, weights should be tied using the `on_post_move_to_device` model hook + which is called after the module has been moved to the device. + + See Also: + - `XLA Documentation `_ + """ @wraps(fn) def inner_f(self, *args, **kwargs): pre_param_count = len(list(self.parameters())) diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index 954deb9c4cdb5..8d6a4a5d13024 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -14,14 +14,13 @@ import pytest import torch -from torch import nn from pytorch_lightning import Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.base import SimpleModule from pytorch_lightning.utilities.xla_device import XLADeviceUtils from tests.base.boring_model import BoringModel from tests.base.develop_utils import pl_multi_process_test +from tests.base.weight_sharing_module import WeightSharingModule @pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") @@ -73,20 +72,6 @@ def test_weight_tying_warning(tmpdir, capsys=None): post moving to device. """ - class WeightSharingModule(SimpleModule): - def __init__(self): - super().__init__() - self.layer_1 = nn.Linear(32, 10, bias=False) - self.layer_2 = nn.Linear(10, 32, bias=False) - self.layer_3 = nn.Linear(32, 10, bias=False) - self.layer_3.weight = self.layer_1.weight - - def forward(self, x): - x = self.layer_1(x) - x = self.layer_2(x) - x = self.layer_3(x) - return x - model = WeightSharingModule() trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) @@ -103,24 +88,11 @@ def test_if_weights_tied(tmpdir, capsys=None): Ensure no warning for parameter mismatch is thrown. 
""" - class WeightSharingModule(SimpleModule): - def __init__(self): - super().__init__() - self.layer_1 = nn.Linear(32, 10, bias=False) - self.layer_2 = nn.Linear(10, 32, bias=False) - self.layer_3 = nn.Linear(32, 10, bias=False) - self.layer_3.weight = self.layer_1.weight - - def forward(self, x): - x = self.layer_1(x) - x = self.layer_2(x) - x = self.layer_3(x) - return x - + class Model(WeightSharingModule): def on_post_move_to_device(self): self.layer_3.weight = self.layer_1.weight - model = WeightSharingModule() + model = Model() trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) with pytest.warns(UserWarning) as warnings: diff --git a/tests/base/weight_sharing_module.py b/tests/base/weight_sharing_module.py new file mode 100644 index 0000000000000..924b5233aa7cb --- /dev/null +++ b/tests/base/weight_sharing_module.py @@ -0,0 +1,18 @@ +from torch import nn + +from tests.base import SimpleModule + + +class WeightSharingModule(SimpleModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + self.layer_3.weight = self.layer_1.weight + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x From bb4c89151d8c83a8272d890a63e33c1fbf5b1b66 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 12 Jan 2021 22:04:01 +0530 Subject: [PATCH 04/27] Update tests/backends/test_tpu_backend.py Co-authored-by: Jirka Borovec --- tests/backends/test_tpu_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index 8d6a4a5d13024..fdd45435bcb67 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -64,7 +64,7 @@ def test_if_test_works_after_train(tmpdir): assert trainer.test() == 1 -@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") @pl_multi_process_test def test_weight_tying_warning(tmpdir, capsys=None): """ From 9897945ec9031428d7c404d73ebd94609bb9ac33 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 12 Jan 2021 22:04:22 +0530 Subject: [PATCH 05/27] Update docs/source/tpu.rst Co-authored-by: Jirka Borovec --- docs/source/tpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index 3648c4630e838..0f9e7dd13ec46 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -212,7 +212,7 @@ throws a warning message. Example: -.. code-block:: python +.. 
testcode:: import pytorch_lightning as pl from torch import nn From 0528db835145960adb09d9c9c0c47b91e813214e Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 12 Jan 2021 22:04:36 +0530 Subject: [PATCH 06/27] Update docs/source/tpu.rst Co-authored-by: Jirka Borovec --- docs/source/tpu.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index 0f9e7dd13ec46..d906a61df5ef8 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -239,7 +239,6 @@ Example: model = WeightSharingModule() trainer = Trainer(max_epochs=1, tpu_cores=8) - result = trainer.fit(model) See `XLA Documentation `_ From c41e4ac73c16c9a0632dd7c3530cce0a8909a488 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 12 Jan 2021 22:06:33 +0530 Subject: [PATCH 07/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Jirka Borovec --- pytorch_lightning/core/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index d63b390eb1484..ccacc82506a1b 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -86,7 +86,7 @@ def parameter_validation(fn: Callable) -> Callable: - `XLA Documentation `_ """ @wraps(fn) - def inner_f(self, *args, **kwargs): + def inner_fn(self, *args, **kwargs): pre_param_count = len(list(self.parameters())) module = fn(self, *args, **kwargs) self.on_post_move_to_device() From c7866ff4433c8756288873565978fc2be74cbcdd Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 12 Jan 2021 22:07:16 +0530 Subject: [PATCH 08/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Jirka Borovec --- pytorch_lightning/core/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index ccacc82506a1b..c78ba6dcfb61f 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -99,4 +99,4 @@ def inner_fn(self, *args, **kwargs): return module - return inner_f + return inner_fn From ac494d1f4910cd4f556554d6d002bbb29e2cf1fb Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 14 Jan 2021 18:57:49 +0530 Subject: [PATCH 09/27] Update docs/source/tpu.rst Co-authored-by: Rohit Gupta --- docs/source/tpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index d906a61df5ef8..fb130b9a5f9e1 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -195,7 +195,7 @@ Under the hood the xla library will use the `bfloat16 type Date: Thu, 14 Jan 2021 19:00:01 +0530 Subject: [PATCH 10/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index c78ba6dcfb61f..8da6e9d34e324 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -69,7 +69,7 @@ def auto_transfer_args(self, *args, **kwargs): def parameter_validation(fn: Callable) -> Callable: """ - Decorator for `~pytorch_lightning.core.LightningModule.to` method. + Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method. Validates that the module parameter lengths match after moving to the device. It is useful when tying weights on TPU's. 
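A side note on why the length check in `parameter_validation` is a reliable signal: `torch.nn.Module.parameters()` yields each parameter object only once, so tying two layers lowers the count, and a device move that silently breaks the tie makes it grow back. A small CPU-only sketch, illustrative and not part of this patch series::

    from torch import nn

    model = nn.ModuleDict({
        "layer_1": nn.Linear(32, 10, bias=False),
        "layer_2": nn.Linear(10, 32, bias=False),
        "layer_3": nn.Linear(32, 10, bias=False),
    })
    print(len(list(model.parameters())))  # 3 distinct weight tensors

    # Tie layer_3 to layer_1: parameters() de-duplicates shared tensors.
    model["layer_3"].weight = model["layer_1"].weight
    print(len(list(model.parameters())))  # 2 -- the "pre move" count

    # If the move to the device copied the tensors independently (as XLA does
    # for weights that are not re-tied), the count would return to 3 after
    # `.to()`, which is exactly the mismatch the decorator warns about.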
From e1126b2de3516eee1cfd0be4d2fb6181fcb7e9f2 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 14 Jan 2021 19:00:16 +0530 Subject: [PATCH 11/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 8da6e9d34e324..39d112cb4a321 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -74,7 +74,7 @@ def parameter_validation(fn: Callable) -> Callable: when tying weights on TPU's. Args: - fn: `.to` method + fn: ``.to`` method Note: TPU's require weights to be tied/shared after moving the module to the device. From 975b8999f395e4cd5aae8e06558f686c5c99c844 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 14 Jan 2021 19:00:35 +0530 Subject: [PATCH 12/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 39d112cb4a321..d5a6469d75629 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -79,7 +79,7 @@ def parameter_validation(fn: Callable) -> Callable: Note: TPU's require weights to be tied/shared after moving the module to the device. Failure to do this results in the initialization of new weights which are not tied. - To overcome this issue, weights should be tied using the `on_post_move_to_device` model hook + To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook which is called after the module has been moved to the device. See Also: From 814b163b39c6296762c726e191e96f1f1effde6c Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 14 Jan 2021 19:01:21 +0530 Subject: [PATCH 13/27] Update pytorch_lightning/core/decorators.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/decorators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index d5a6469d75629..37e77280101bf 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -93,9 +93,9 @@ def inner_fn(self, *args, **kwargs): post_param_count = len(list(self.parameters())) if not pre_param_count == post_param_count: - rank_zero_warn('The model parameters do not match after moving to the target device. ' - 'If your model employs weight sharing on TPU,' - 'please tie your weights using the `on_post_move_to_device` model hook.') + rank_zero_warn('The model parameters do not match after moving to the target device.' 
+ ' If your model employs weight sharing on TPU,' + ' please tie your weights using the `on_post_move_to_device` model hook.') return module From e69910a7fc45108f318c8622011641115d7da46a Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 14 Jan 2021 19:02:09 +0530 Subject: [PATCH 14/27] Update pytorch_lightning/core/hooks.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/hooks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 00f151924e704..c13ff90dc33f4 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -316,9 +316,9 @@ def on_after_backward(self): def on_post_move_to_device(self) -> None: """ - Called in the parameter_validation decorator after Lightning.to is called - This is a good place to tie weights between modules after moving them to a device. - Can be used when training models with weight sharing properties on TPU. + Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to` + is called. This is a good place to tie weights between modules after moving them to a device. Can be + used when training models with weight sharing properties on TPU. Addresses the handling of shared weights on TPU: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks From 8a22096bd04588eb20dc1e6a211102d842c7bc85 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 16 Jan 2021 15:21:18 +0530 Subject: [PATCH 15/27] moved weight sharing module back to test updated tpu available --- tests/backends/test_tpu_backend.py | 26 +++++++++++++++++++++----- tests/base/weight_sharing_module.py | 18 ------------------ 2 files changed, 21 insertions(+), 23 deletions(-) delete mode 100644 tests/base/weight_sharing_module.py diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index fdd45435bcb67..64d124250026e 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -14,16 +14,32 @@ import pytest import torch +from torch import nn from pytorch_lightning import Trainer from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities.xla_device import XLADeviceUtils +from pytorch_lightning.utilities import _TPU_AVAILABLE +from tests.base import SimpleModule from tests.base.boring_model import BoringModel from tests.base.develop_utils import pl_multi_process_test -from tests.base.weight_sharing_module import WeightSharingModule -@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +class WeightSharingModule(SimpleModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + self.layer_3.weight = self.layer_1.weight + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x + + +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") @pl_multi_process_test def test_resume_training_on_cpu(tmpdir): """ Checks if training can be resumed from a saved checkpoint on CPU""" @@ -51,7 +67,7 @@ def test_resume_training_on_cpu(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU 
machine") @pl_multi_process_test def test_if_test_works_after_train(tmpdir): """ Ensure that .test() works after .fit() """ @@ -80,7 +96,7 @@ def test_weight_tying_warning(tmpdir, capsys=None): assert result -@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine") +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") @pl_multi_process_test def test_if_weights_tied(tmpdir, capsys=None): """ diff --git a/tests/base/weight_sharing_module.py b/tests/base/weight_sharing_module.py deleted file mode 100644 index 924b5233aa7cb..0000000000000 --- a/tests/base/weight_sharing_module.py +++ /dev/null @@ -1,18 +0,0 @@ -from torch import nn - -from tests.base import SimpleModule - - -class WeightSharingModule(SimpleModule): - def __init__(self): - super().__init__() - self.layer_1 = nn.Linear(32, 10, bias=False) - self.layer_2 = nn.Linear(10, 32, bias=False) - self.layer_3 = nn.Linear(32, 10, bias=False) - self.layer_3.weight = self.layer_1.weight - - def forward(self, x): - x = self.layer_1(x) - x = self.layer_2(x) - x = self.layer_3(x) - return x From bf5349b0ec137ef219330b531cc11fff41794b69 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 16 Jan 2021 17:48:58 +0530 Subject: [PATCH 16/27] add count to warning --- pytorch_lightning/core/decorators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 37e77280101bf..78151bb0db350 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -93,9 +93,10 @@ def inner_fn(self, *args, **kwargs): post_param_count = len(list(self.parameters())) if not pre_param_count == post_param_count: - rank_zero_warn('The model parameters do not match after moving to the target device.' + rank_zero_warn(f'The model parameters do not match after moving to the target device.' ' If your model employs weight sharing on TPU,' - ' please tie your weights using the `on_post_move_to_device` model hook.') + ' please tie your weights using the `on_post_move_to_device` model hook.' + f'Parameter count: [Before: {pre_param_count} After: {post_param_count}]') return module From 35e67c00832aff63b869e9a40139b644d214e447 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 16 Jan 2021 19:04:30 +0530 Subject: [PATCH 17/27] fix doctest --- docs/source/tpu.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index fb130b9a5f9e1..326688c59aeb5 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -214,11 +214,11 @@ Example: .. 
testcode:: - import pytorch_lightning as pl + from pytorch_lightning.core.lightning import LightningModule from torch import nn - class WeightSharingModule(pl.LightningModule): + class WeightSharingModule(LightningModule): def __init__(self): super().__init__() self.layer_1 = nn.Linear(32, 10, bias=False) From 829c04154a1b8ca35a7f90d86b702821a0cc2303 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 16 Jan 2021 19:25:38 +0530 Subject: [PATCH 18/27] import trainer in doctest --- docs/source/tpu.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index 326688c59aeb5..8115b1aacc939 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -216,6 +216,7 @@ Example: from pytorch_lightning.core.lightning import LightningModule from torch import nn + from pytorch_lightning import Trainer class WeightSharingModule(LightningModule): From a1e574f3872f80688caf813c558b2d6eded54df4 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Mon, 18 Jan 2021 11:05:54 +0530 Subject: [PATCH 19/27] import trainer in doctest --- docs/source/tpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index 8115b1aacc939..ad4138cc53d9c 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -216,7 +216,7 @@ Example: from pytorch_lightning.core.lightning import LightningModule from torch import nn - from pytorch_lightning import Trainer + from pytorch_lightning.trainer.trainer import Trainer class WeightSharingModule(LightningModule): From ac71a15d0b356ab84dc3703f1d0fa0b9f6ff4bd4 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 19 Jan 2021 07:37:45 +0530 Subject: [PATCH 20/27] do not test code as no TPU device --- docs/source/tpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index ad4138cc53d9c..33a0177b3f17e 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -212,7 +212,7 @@ throws a warning message. Example: -.. testcode:: +.. code-block:: python from pytorch_lightning.core.lightning import LightningModule from torch import nn From 432c3c3801a701f9b8974d07315cc87f617ee739 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 20 Jan 2021 18:58:46 +0530 Subject: [PATCH 21/27] param count to layer count --- pytorch_lightning/core/decorators.py | 10 +++++----- tests/backends/test_tpu_backend.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 78151bb0db350..317139107bd88 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -87,16 +87,16 @@ def parameter_validation(fn: Callable) -> Callable: """ @wraps(fn) def inner_fn(self, *args, **kwargs): - pre_param_count = len(list(self.parameters())) + pre_layer_count = len(list(self.parameters())) module = fn(self, *args, **kwargs) self.on_post_move_to_device() - post_param_count = len(list(self.parameters())) + post_layer_count = len(list(self.parameters())) - if not pre_param_count == post_param_count: - rank_zero_warn(f'The model parameters do not match after moving to the target device.' + if not pre_layer_count == post_layer_count: + rank_zero_warn(f'The model layers do not match after moving to the target device.' ' If your model employs weight sharing on TPU,' ' please tie your weights using the `on_post_move_to_device` model hook.' 
- f'Parameter count: [Before: {pre_param_count} After: {post_param_count}]') + f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]') return module diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index 64d124250026e..79c1318766595 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -91,7 +91,7 @@ def test_weight_tying_warning(tmpdir, capsys=None): model = WeightSharingModule() trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) - with pytest.warns(UserWarning, match=r'The model parameters do not match after moving to the target device.'): + with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'): result = trainer.fit(model) assert result @@ -115,4 +115,4 @@ def on_post_move_to_device(self): result = trainer.fit(model) assert result - assert not list(filter(lambda x: 'The model parameters do not match' in str(x), warnings.list)) + assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list)) From beb1ab43038b2e1daf520b30b82133d263076d9a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 28 Jan 2021 10:44:17 +0100 Subject: [PATCH 22/27] formatting --- pytorch_lightning/core/decorators.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 317139107bd88..55d6bb9db1ed4 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -93,10 +93,12 @@ def inner_fn(self, *args, **kwargs): post_layer_count = len(list(self.parameters())) if not pre_layer_count == post_layer_count: - rank_zero_warn(f'The model layers do not match after moving to the target device.' - ' If your model employs weight sharing on TPU,' - ' please tie your weights using the `on_post_move_to_device` model hook.' - f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]') + rank_zero_warn( + f'The model layers do not match after moving to the target device.' + ' If your model employs weight sharing on TPU,' + ' please tie your weights using the `on_post_move_to_device` model hook.\n' + f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]' + ) return module From 4cf6799a4627e455242f6308fb22c0c7fd7699da Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Mon, 1 Feb 2021 18:48:14 +0530 Subject: [PATCH 23/27] update docs --- docs/source/advanced/tpu.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/advanced/tpu.rst b/docs/source/advanced/tpu.rst index 33a0177b3f17e..ff938f4aef9e9 100644 --- a/docs/source/advanced/tpu.rst +++ b/docs/source/advanced/tpu.rst @@ -225,7 +225,10 @@ Example: self.layer_1 = nn.Linear(32, 10, bias=False) self.layer_2 = nn.Linear(10, 32, bias=False) self.layer_3 = nn.Linear(32, 10, bias=False) - self.layer_3.weight = self.layer_1.weight # Weights will be copied on TPU + # TPU shared weights are copied independently + # on the XLA device and this line won't have any effect. + # However, it works fine for CPU and GPU. 
+ self.layer_3.weight = self.layer_1.weight def forward(self, x): x = self.layer_1(x) From 2268df8b6f172ac97c281d529d0de901484ed52d Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 20:19:34 +0000 Subject: [PATCH 24/27] update import --- pytorch_lightning/accelerators/legacy/tpu_accelerator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py index 25b0fff1896ba..7a12eec055eaa 100644 --- a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py @@ -21,7 +21,7 @@ from torch.optim import Optimizer from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.legacy.accelerator import Accelerator, ReduceOp +from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.utilities import ( @@ -32,6 +32,7 @@ rank_zero_warn, ) from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: From 8acfd39f97425880fdcd85dcea72aa4fa0984b5a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 20:49:21 +0000 Subject: [PATCH 25/27] update --- tests/accelerators/test_tpu_backend.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 3e0843c9824a6..706cb0a2b4b08 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -19,12 +19,11 @@ from pytorch_lightning import Trainer from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE -from tests.base import SimpleModule -from tests.base.boring_model import BoringModel -from tests.base.develop_utils import pl_multi_process_test +from tests.helpers.boring_model import BoringModel +from tests.helpers.utils import pl_multi_process_test -class WeightSharingModule(SimpleModule): +class WeightSharingModule(BoringModel): def __init__(self): super().__init__() From 62042c6f02efc93013df947ecd9065d802040118 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 17 Feb 2021 23:07:25 +0000 Subject: [PATCH 26/27] resolve tests --- tests/accelerators/test_tpu_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 706cb0a2b4b08..daea22968b253 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -79,7 +79,7 @@ def test_if_test_works_after_train(tmpdir): model = BoringModel() trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model) - assert trainer.test() == 1 + assert trainer.test(model) == 1 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") From 3605c0bbff158e277dbc82104b91a55f95b97a0c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 23:41:54 +0000 Subject: [PATCH 27/27] remove legacy accelerator --- .../accelerators/legacy/tpu_accelerator.py | 370 ------------------ 1 file changed, 370 deletions(-) delete mode 100644 pytorch_lightning/accelerators/legacy/tpu_accelerator.py diff --git 
a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py deleted file mode 100644 index 7a12eec055eaa..0000000000000 --- a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import io -import os -import re -from typing import Any, Callable, Optional, Union - -import torch -import torch.multiprocessing as mp -from torch.optim import Optimizer - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.core import LightningModule -from pytorch_lightning.plugins.environments import ClusterEnvironment -from pytorch_lightning.utilities import ( - _TPU_AVAILABLE, - move_data_to_device, - rank_zero_info, - rank_zero_only, - rank_zero_warn, -) -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _TPU_AVAILABLE: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.distributed.xla_multiprocessing as xmp - - -class TPUAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using TPUs (colab, single machine or pod) - - Example:: - - # default - trainer = Trainer(accelerator=TPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.start_method = None - self.mp_queue = None - self.nickname = None - - def setup(self, model): - rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores') - - # TODO: Move this check to Trainer __init__ or device parser - if not _TPU_AVAILABLE: - raise MisconfigurationException('PyTorch XLA not installed.') - - # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2 - self.start_method = 'fork' - - # pass in a state q - smp = mp.get_context(self.start_method) - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def teardown(self): - model = self.trainer.model - - # restore main state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - last_path = self.mp_queue.get() - - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback is not None: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also bets score - - # load last weights - if last_path and not self.trainer.testing: - ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - - # when training completes, load the weights back in main process - self.__load_weights_on_main_process() - return results - - def train(self): - model = self.trainer.model - - # train - if self.trainer.tpu_id is not None: - 
self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue) - else: - xmp.spawn( - self.tpu_train_in_process, - args=(model, self.trainer, self.mp_queue), - nprocs=self.trainer.tpu_cores, - start_method=self.start_method - ) - - def __load_weights_on_main_process(self): - model = self.trainer.model - - # load weights if not interrupted - if self.trainer.on_colab_kaggle and not self.trainer.testing: - self.load_spawn_weights(model) - - self.trainer.model = model - - def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None): - """ - Here we are inside each individual process - """ - # Todo: required argument `tpu_core_idx` is not used - if not trainer: - trainer = self.trainer - - trainer.call_setup_hook(model) - - # setup TPU training - self.__setup_tpu_training(model, trainer) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # save weights at the end of training - self.__save_end_of_training_weights(model, trainer) - - # persist info in spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - def _step(self, model_step: Callable, args): - args[0] = self.to_device(args[0]) - return model_step(*args) - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def predict(self, args): - return self._step(self.trainer.model.predict, args) - - def process_dataloader(self, dataloader): - device = xm.xla_device(self.trainer.tpu_id) - dataloader = xla_pl.ParallelLoader(dataloader, [device]) - dataloader = dataloader.per_device_loader(device) - return dataloader - - def to_device(self, batch): - """ - Transfers the data to the TPU. - - Args: - batch: A tensor or collection of tensors. - - Return: - the tensor on the TPU device. - - See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - """ - if not _TPU_AVAILABLE: - raise MisconfigurationException( - 'Requested to transfer batch to TPU but XLA is not available.' - ' Are you sure this machine has TPUs?' - ) - device = xm.xla_device(self.trainer.tpu_id) - - return self.batch_to_device(batch, device) - - def __save_end_of_training_weights(self, model: LightningModule, trainer): - # when training ends on these platforms dump weights to get out of the main process - if trainer.on_colab_kaggle: - rank_zero_warn('cleaning up... 
please do not interrupt') - self.save_spawn_weights(model) - - def __setup_tpu_training(self, model: LightningModule, trainer): - # get the appropriate tpu ranks - trainer.tpu_local_core_rank = xm.get_local_ordinal() - trainer.tpu_global_core_rank = xm.get_ordinal() - - # avoid duplicating progress bar - if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: - trainer.progress_bar_callback.disable() - - trainer.global_rank = trainer.tpu_local_core_rank - rank_zero_only.rank = trainer.global_rank - - # if given an ordinal device, use this as the device - if trainer.tpu_id is not None: - tpu_device = xm.xla_device(trainer.tpu_id) - else: - tpu_device = xm.xla_device() - # track the device and move model to it - trainer._device = tpu_device - model.to(trainer._device) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # init 16 bit for TPU - if trainer.precision == 16: - os.environ['XLA_USE_BF16'] = str(1) - - log.info( - f'INIT TPU local core: {trainer.tpu_local_core_rank},' - f' global rank: {trainer.tpu_global_core_rank}' - f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}' - ) - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - # do backward pass - if self.trainer.train_loop.automatic_optimization: - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx) - else: - closure_loss.backward(*args, **kwargs) - - # detach after backward - closure_loss = closure_loss.detach() - - return closure_loss - - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md - model = self.trainer.get_model() - parameters = model.parameters() - max_norm = grad_clip_val - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon) - clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) - for p in parameters: - p.grad.data.mul_(clip_coef.to(p.grad.data.device)) - - def barrier(self, name: Optional[str] = None): - torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}") - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32) - stop = xm.mesh_reduce("stop_signal", stop, sum) - torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.trainer.world_size - return should_stop - - def save_spawn_weights(self, model): - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - """ - # Todo: required argument `model` is not used - if self.trainer.is_global_zero: - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - self.trainer.save_checkpoint(path) - return path - - def load_spawn_weights(self, original_model): - """ - Load the temp weights saved in the process - To recover the trained model from the ddp process we load the saved weights - """ - - 
loaded_model = original_model - - if self.trainer.is_global_zero: - # load weights saved in ddp - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - loaded_model = original_model.__class__.load_from_checkpoint(path) - - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) - - # remove ddp weights - os.remove(path) - - return loaded_model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"): - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - state_dict = move_data_to_device(model.state_dict(), torch.device("cpu")) - atomic_save(state_dict, last_path) - mp_queue.put(last_path) - - def broadcast(self, obj, src=0): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj - - def sync_tensor( - self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None - ) -> torch.Tensor: - return tensor - - @property - def norm_clipping_epsilon(self): - return 1e-6 - - def on_save(self, checkpoint): - """ - Move XLA tensors to CPU before saving - Recommended on XLA Guide: - https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors - """ - return move_data_to_device(checkpoint, torch.device("cpu")) - - @property - def distributed_sampler_kwargs(self): - return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - - @property - def require_distributed_sampler(self): - return True
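Finally, a quick way to sanity-check the tying behaviour without TPU hardware; this is an illustrative snippet, not part of the PR's test suite. On CPU (and typically CUDA) `.to()` keeps tied parameters pointing at the same storage, whereas an XLA device move is where the tie can silently break, which is what `on_post_move_to_device` addresses::

    import torch
    from torch import nn


    class WeightSharingModule(nn.Module):
        """Same layer layout as the test model above, minus the Lightning pieces."""

        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            self.layer_3.weight = self.layer_1.weight


    model = WeightSharingModule().to(torch.device("cpu"))

    # Tied parameters should still reference the same underlying storage.
    assert model.layer_1.weight.data_ptr() == model.layer_3.weight.data_ptr()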