9 changes: 4 additions & 5 deletions pytorch_lightning/core/decorators.py
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):

def parameter_validation(fn: Callable) -> Callable:
Contributor (@awaelchli) commented:
I think that now that you changed the decorator target to self.model, this decorator may no longer fit well into core/decorators, because it is now basically specific to plugins that have the attribute self.model.
What do you think about moving it?

Just for consideration.

Contributor commented:
+1 to @awaelchli's suggestion.

Contributor (author) replied:
Good catch. Will do a follow-up PR for this.

"""
Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
Validates that the module parameter lengths match after moving to the device. It is useful
when tying weights on TPU's.

Args:
- fn: ``.to`` method
+ fn: ``model_to_device`` method

Note:
TPU's require weights to be tied/shared after moving the module to the device.
@@ -90,10 +89,10 @@ def parameter_validation(fn: Callable) -> Callable:

@wraps(fn)
def inner_fn(self, *args, **kwargs):
- pre_layer_count = len(list(self.parameters()))
+ pre_layer_count = len(list(self.model.parameters()))
module = fn(self, *args, **kwargs)
- self.on_post_move_to_device()
- post_layer_count = len(list(self.parameters()))
+ self.model.on_post_move_to_device()
+ post_layer_count = len(list(self.model.parameters()))

if not pre_layer_count == post_layer_count:
rank_zero_warn(
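For readers unfamiliar with why a simple length comparison catches broken weight tying: ``nn.Module.parameters()`` de-duplicates shared tensors, so a tie that silently breaks during a device transfer shows up as a larger unique-parameter count. The sketch below is a CPU-only illustration of that idea; ``TiedModule`` and the manual re-assignment that simulates the XLA behaviour are hypothetical and not part of this PR.

```python
from torch import nn


class TiedModule(nn.Module):
    """Toy module whose layer_3.weight is tied to layer_1.weight (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 4)
        self.layer_2 = nn.Linear(4, 4)
        self.layer_3 = nn.Linear(4, 4)
        # parameters() de-duplicates shared tensors, so this tie lowers the unique count.
        self.layer_3.weight = self.layer_1.weight


module = TiedModule()
pre_layer_count = len(list(module.parameters()))   # 5 unique parameters (weight shared)

# Simulate what an XLA transfer can do: the tied tensor is re-created per layer,
# silently breaking the tie unless the user re-ties it in `on_post_move_to_device`.
module.layer_3.weight = nn.Parameter(module.layer_1.weight.detach().clone())
post_layer_count = len(list(module.parameters()))  # 6 unique parameters

if pre_layer_count != post_layer_count:
    print("The model layers do not match after moving to the target device.")
```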
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -15,6 +15,7 @@

import torch

+ from pytorch_lightning.core.decorators import parameter_validation
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -39,6 +40,7 @@ def __init__(self, device: int, debug: bool = False):
def is_distributed(self) -> bool:
return False

+ @parameter_validation
def model_to_device(self) -> None:
self.model.to(self.root_device)

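As a usage note, the only contract the decorator relies on is that the wrapped plugin method's ``self`` exposes a ``model`` with ``parameters()`` and an ``on_post_move_to_device()`` hook, which is the point raised in the review thread above. Below is a minimal, self-contained sketch of that wiring; ``ToySingleDevicePlugin``, ``TiedModel``, and the simplified decorator body are illustrative stand-ins, not the actual Lightning classes.

```python
from functools import wraps
from typing import Callable
import warnings

import torch
from torch import nn


def parameter_validation(fn: Callable) -> Callable:
    # Simplified stand-in for pytorch_lightning.core.decorators.parameter_validation
    # (the real one warns via rank_zero_warn); shown only to make the contract explicit.
    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.model.parameters()))
        result = fn(self, *args, **kwargs)
        self.model.on_post_move_to_device()
        post_layer_count = len(list(self.model.parameters()))
        if pre_layer_count != post_layer_count:
            warnings.warn("The model layers do not match after moving to the target device.")
        return result

    return inner_fn


class TiedModel(nn.Module):
    # Hypothetical model that re-ties its shared weight after a device move.
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 4)
        self.layer_3 = nn.Linear(4, 4)
        self.layer_3.weight = self.layer_1.weight

    def on_post_move_to_device(self):
        self.layer_3.weight = self.layer_1.weight


class ToySingleDevicePlugin:
    # Hypothetical plugin: the decorator only needs `self.model` on the plugin.
    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.root_device = device

    @parameter_validation
    def model_to_device(self) -> None:
        self.model.to(self.root_device)


plugin = ToySingleDevicePlugin(TiedModel(), torch.device("cpu"))
plugin.model_to_device()  # counts match before and after, so no warning is emitted
```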
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,6 +23,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
+ from pytorch_lightning.core.decorators import parameter_validation
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.local_rank == 0:
time.sleep(2)

+ @parameter_validation
Collaborator commented:
How slow is this?

def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

6 changes: 0 additions & 6 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,8 +17,6 @@
import torch
from torch.nn import Module

- from pytorch_lightning.core.decorators import parameter_validation


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ['device', 'dtype']
@@ -47,7 +45,6 @@ def device(self) -> Union[str, torch.device]:

return device

- @parameter_validation
def to(self, *args, **kwargs) -> Module:
"""Moves and/or casts the parameters and buffers.

@@ -84,9 +81,6 @@ def to(self, *args, **kwargs) -> Module:
... def __init__(self, weight: torch.Tensor):
... super().__init__()
... self.register_buffer('weight', weight)
- ...
- ... def on_post_move_to_device(self):
- ...     pass
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
32 changes: 14 additions & 18 deletions tests/accelerators/test_tpu_backend.py
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
trainer.fit(model)


- # @RunIf(tpu=True)
- # @pl_multi_process_test
- # def test_if_weights_tied(tmpdir, capsys=None):
- # """
- # Test if weights are properly tied on `on_post_move_to_device`.
- # Ensure no warning for parameter mismatch is thrown.
- # """

- # # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
- # class Model(WeightSharingModule):
+ @RunIf(tpu=True)
+ @pl_multi_process_test
+ def test_if_weights_tied(tmpdir, capsys=None):
+ """
+ Test if weights are properly tied on `on_post_move_to_device`.
+ Ensure no warning for parameter mismatch is thrown.
+ """

- # def on_post_move_to_device(self):
- # self.layer_3.weight = self.layer_1.weight
+ class Model(WeightSharingModule):

- # model = Model()
- # trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+ def on_post_move_to_device(self):
Contributor commented:

Could we make this check slightly smarter by checking parameter names?

If I do self.layer_3.weight = self.layer_1.weight in the init function and mess up and write self.layer_3.weight = self.layer_2.weight instead, I won't get a warning, but the tying is different. Ideally it would be great to explicitly tell the user which weights are shared, or to do it automatically for them.

Contributor (author) replied:

Interesting, will follow up.
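One way to make the check name-aware, as suggested above, is to record which fully qualified parameter names point at the same tensor before and after the move and compare the groups. The helper below is a hypothetical sketch, not part of this PR; ``TiedModel`` only mirrors the shapes of the test's ``WeightSharingModule`` for illustration.

```python
from collections import defaultdict

from torch import nn


def shared_parameter_groups(module: nn.Module):
    """Group fully qualified parameter names that refer to the same tensor."""
    groups = defaultdict(list)
    for module_name, submodule in module.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            groups[id(param)].append(full_name)
    # Keep only tensors reachable under more than one name, i.e. real ties.
    return sorted(names for names in groups.values() if len(names) > 1)


class TiedModel(nn.Module):
    # Hypothetical module mirroring the test's WeightSharingModule.
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10)
        self.layer_2 = nn.Linear(10, 32)
        self.layer_3 = nn.Linear(32, 10)
        self.layer_3.weight = self.layer_1.weight


model = TiedModel()
groups_before = shared_parameter_groups(model)  # [['layer_1.weight', 'layer_3.weight']]

# Simulate a wrong re-tie after the device move: the unique-parameter count is
# unchanged, so the length check passes, but the name groups reveal the mistake.
model.layer_3.weight = model.layer_2.weight
groups_after = shared_parameter_groups(model)

if groups_before != groups_after:
    print(f"Shared parameters changed: {groups_before} -> {groups_after}")
```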

+ self.layer_3.weight = self.layer_1.weight

- # with pytest.warns(UserWarning) as warnings:
- # trainer.fit(model)
+ model = Model()
+ trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

- # assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
- # assert len(trainer.test(model)) == 1
+ with pytest.warns(UserWarning, match="The model layers do not match"):
+     trainer.fit(model)