Merged

33 commits
dac83cf
Enable auto parameters tying for TPUs
kaushikb11 Sep 14, 2021
bc45547
Update plugins
kaushikb11 Sep 14, 2021
9644d3d
Update decorators
kaushikb11 Sep 15, 2021
d4ba103
Update tests
kaushikb11 Sep 28, 2021
7ad20c1
Update tests
kaushikb11 Sep 28, 2021
3fdb3ee
Update typing
kaushikb11 Sep 28, 2021
7a1a921
Address reviews
kaushikb11 Sep 28, 2021
a5cb251
Update pytorch_lightning/core/decorators.py
kaushikb11 Sep 29, 2021
adf3f96
Update tests/accelerators/test_tpu_backend.py
kaushikb11 Sep 29, 2021
4247483
Address deprecation concerns
kaushikb11 Sep 29, 2021
f9c196f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2021
24b4c8d
Scrap the decorator logic
kaushikb11 Sep 29, 2021
00a7de8
Update tpu spawn plugin
kaushikb11 Sep 29, 2021
aa8f7df
Update decorators
kaushikb11 Sep 29, 2021
5e4ec17
Update tests
kaushikb11 Sep 29, 2021
82bf7d3
Update tests
kaushikb11 Sep 30, 2021
8f7ecbc
Address reviews
kaushikb11 Sep 30, 2021
22d5556
Merge branch 'master' into feat/enable_auto_tying__for_tpus
kaushikb11 Sep 30, 2021
63b03b1
Update deprecation msg
kaushikb11 Sep 30, 2021
68775ed
Add ParameterSharingModule
kaushikb11 Sep 30, 2021
2dd878a
Update changelog
kaushikb11 Sep 30, 2021
78b787d
Fix mypy
kaushikb11 Sep 30, 2021
8f7e828
Fix test
kaushikb11 Sep 30, 2021
c8ec0c3
Fix for tpu spawn
kaushikb11 Sep 30, 2021
4d6435b
Update CHANGELOG.md
kaushikb11 Oct 1, 2021
4031cf9
Update pytorch_lightning/core/decorators.py
kaushikb11 Oct 1, 2021
8f02d16
Update pytorch_lightning/trainer/configuration_validator.py
kaushikb11 Oct 1, 2021
942e3a8
Update pytorch_lightning/utilities/parameter_tying.py
kaushikb11 Oct 1, 2021
4e1df7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2021
0117bea
Address reviews
kaushikb11 Oct 1, 2021
f17cf25
Merge branch 'feat/enable_auto_tying__for_tpus' of https://github.com…
kaushikb11 Oct 1, 2021
d83d7ad
Fix test
kaushikb11 Oct 1, 2021
6ad1dd9
Update
kaushikb11 Oct 5, 2021
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -166,6 +166,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121))


- Enabled automatic parameter tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -296,6 +299,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


- Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,6 +80,7 @@ module = [
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.memory",
"pytorch_lightning.utilities.model_summary",
"pytorch_lightning.utilities.parameter_tying",
"pytorch_lightning.utilities.parsing",
"pytorch_lightning.utilities.seed",
"pytorch_lightning.utilities.xla_device",
14 changes: 10 additions & 4 deletions pytorch_lightning/core/decorators.py
@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decorator for LightningModule methods."""
from pytorch_lightning.utilities import rank_zero_deprecation

from functools import wraps
from typing import Callable
rank_zero_deprecation(
"Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
"and will be removed in v1.7. It has been replaced by automatic parameters tying with "
"`pytorch_lightning.utilities.params_tying.set_shared_parameters`"
)

from pytorch_lightning.utilities import rank_zero_warn
from functools import wraps # noqa: E402
from typing import Callable # noqa: E402

from pytorch_lightning.utilities import rank_zero_warn # noqa: E402


def parameter_validation(fn: Callable) -> Callable:
18 changes: 15 additions & 3 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -14,12 +14,17 @@
import os
from typing import Any, Dict

from pytorch_lightning.core.decorators import parameter_validation
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
@@ -49,7 +54,14 @@ def __init__(
def is_distributed(self) -> bool:
return False

@parameter_validation
def setup(self) -> None:
shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.on_post_move_to_device()
else:
set_shared_parameters(self.model, shared_params)

def model_to_device(self) -> None:
self.model.to(self.root_device)

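For context, a hedged sketch of the kind of LightningModule this new `setup()` handles. The layer layout mirrors the `WeightSharingModule` helper referenced in the tests further down; its exact definition is not part of this diff, so treat the class below as an assumption rather than the project's actual helper.

import torch.nn as nn

from tests.helpers.boring_model import BoringModel


class WeightSharingModule(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight  # tie the two layers at construction time

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        return self.layer_3(x)

Previously, keeping this tie on a TPU meant overriding the now-deprecated `on_post_move_to_device` hook and re-assigning `self.layer_3.weight = self.layer_1.weight` there; with the plugin change above, the groups recorded by `find_shared_parameters` before `model_to_device()` are restored by `set_shared_parameters` right after it.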
17 changes: 14 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,17 +23,23 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.decorators import parameter_validation
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
_TPU_AVAILABLE,
find_shared_parameters,
rank_zero_warn,
set_shared_parameters,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

@@ -156,7 +162,13 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)

trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

@@ -176,7 +188,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()

@parameter_validation
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -46,6 +46,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
self._check_add_get_queue(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
self._check_progress_bar(model)
# TODO: Delete _check_on_post_move_to_device in v1.7
self._check_on_post_move_to_device(model)
# TODO: Delete _check_on_keyboard_interrupt in v1.7
self._check_on_keyboard_interrupt()

@@ -127,6 +129,19 @@ def _check_progress_bar(self, model: "pl.LightningModule") -> None:
" Please use the `ProgressBarBase.get_metrics` instead."
)

def _check_on_post_move_to_device(self, model: "pl.LightningModule") -> None:
r"""
Checks if the `on_post_move_to_device` method is overridden and emits a deprecation warning.

Args:
model: The model to check for an `on_post_move_to_device` override.
"""
if is_overridden("on_post_move_to_device", model):
rank_zero_deprecation(
"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
"We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
)

def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
loader_name = f"{stage}_dataloader"
step_name = "validation_step" if stage == "val" else "test_step"
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -57,6 +57,7 @@
_TPU_AVAILABLE,
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters # noqa: F401
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401

71 changes: 71 additions & 0 deletions pytorch_lightning/utilities/parameter_tying.py
@@ -0,0 +1,71 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for automatic parameters tying.

Reference:
https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118
"""
from typing import Dict, List, Optional

from torch import nn
from torch.nn import Parameter


def find_shared_parameters(module: nn.Module) -> List[str]:
"""Returns a list of names of shared parameters set in the module."""
return _find_shared_parameters(module)


def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]:
if tied_parameters is None:
first_call = True
tied_parameters = {}
else:
first_call = False
for name, param in module._parameters.items():
param_prefix = prefix + ("." if prefix else "") + name
if param is None:
continue
if param not in tied_parameters:
tied_parameters[param] = []
tied_parameters[param].append(param_prefix)
for name, m in module._modules.items():
if m is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
_find_shared_parameters(m, tied_parameters, submodule_prefix)
if first_call:
return [x for x in tied_parameters.values() if len(x) > 1]


def set_shared_parameters(module: nn.Module, shared_params: list) -> nn.Module:
for shared_param in shared_params:
ref = _get_module_by_path(module, shared_param[0])
for path in shared_param[1:]:
_set_module_by_path(module, path, ref)
return module


def _get_module_by_path(module: nn.Module, path: str) -> nn.Module:
path = path.split(".")
for name in path:
module = getattr(module, name)
return module


def _set_module_by_path(module: nn.Module, path: str, value: Parameter) -> None:
path = path.split(".")
for name in path[:-1]:
module = getattr(module, name)
setattr(module, path[-1], value)
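
A minimal usage sketch of the two helpers above (not part of the diff; the `TiedHeads` module is illustrative only). Shared parameters are recorded as groups of dotted attribute paths before any device move, and the groups are re-applied afterwards; if the move already preserved the sharing, re-applying them is effectively a no-op.

from torch import nn

from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters


class TiedHeads(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 10, bias=False)
        self.head = nn.Linear(32, 10, bias=False)
        self.head.weight = self.backbone.weight  # both attributes point at the same Parameter


model = TiedHeads()
shared = find_shared_parameters(model)
print(shared)  # -> [['backbone.weight', 'head.weight']]

# ... move `model` to the target device here (e.g. the XLA device in the TPU plugins) ...

set_shared_parameters(model, shared)
assert model.backbone.weight is model.head.weight

This is the same pattern the TPU plugins in this PR follow around `model_to_device()`.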
78 changes: 47 additions & 31 deletions tests/accelerators/test_tpu_backend.py
@@ -23,6 +23,7 @@
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -80,37 +81,6 @@ def test_if_test_works_after_train(tmpdir):
assert len(trainer.test(model)) == 1


@RunIf(tpu=True)
@pl_multi_process_test
def test_weight_tying_warning(tmpdir, capsys=None):
"""Ensure a warning is thrown if model parameter lengths do not match post moving to device."""

model = WeightSharingModule()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

with pytest.warns(UserWarning, match=r"The model layers do not match after moving to the target device."):
trainer.fit(model)


@RunIf(tpu=True)
@pl_multi_process_test
def test_if_weights_tied(tmpdir, capsys=None):
"""Test if weights are properly tied on `on_post_move_to_device`.

Ensure no warning for parameter mismatch is thrown.
"""

class Model(WeightSharingModule):
def on_post_move_to_device(self):
self.layer_3.weight = self.layer_1.weight

model = Model()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

with pytest.warns(UserWarning, match="The model layers do not match"):
trainer.fit(model)


@RunIf(tpu=True)
def test_accelerator_tpu():

@@ -257,3 +227,49 @@ def test_ddp_cpu_not_supported_on_tpus():

with pytest.raises(MisconfigurationException, match="`accelerator='ddp_cpu'` is not supported on TPU machines"):
Trainer(accelerator="ddp_cpu")


@RunIf(tpu=True)
def test_auto_parameters_tying_tpus(tmpdir):

model = WeightSharingModule()
shared_params = find_shared_parameters(model)

assert shared_params[0] == ["layer_1.weight", "layer_3.weight"]

trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
trainer.fit(model)

assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))


@RunIf(tpu=True)
def test_auto_parameters_tying_tpus_nested_module(tmpdir):
class SubModule(nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer

def forward(self, x):
return self.layer(x)

class NestedModule(BoringModel):
def __init__(self):
super().__init__()
self.layer = nn.Linear(32, 10, bias=False)
self.net_a = SubModule(self.layer)
self.layer_2 = nn.Linear(10, 32, bias=False)
self.net_b = SubModule(self.layer)

def forward(self, x):
x = self.net_a(x)
x = self.layer_2(x)
x = self.net_b(x)
return x

model = NestedModule()

trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
trainer.fit(model)

assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
24 changes: 24 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -255,3 +255,27 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
from pytorch_lightning.distributed.dist import LightningDistributed

_ = LightningDistributed()


def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
class TestModel(BoringModel):
def on_post_move_to_device(self):
print("on_post_move_to_device")

model = TestModel()

trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)

with pytest.deprecated_call(
match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
):
trainer.fit(model)


def test_v1_7_0_deprecate_parameter_validation():

_soft_unimport_module("pytorch_lightning.core.decorators")
with pytest.deprecated_call(
match="Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5"
):
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401