Move parameter validation specific to TPU Training plugins #7415
```diff
@@ -23,6 +23,7 @@
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)

+    @parameter_validation
```
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how sow is this? |
```diff
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
```
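For context, `parameter_validation` wraps the plugin's `model_to_device` and warns when the model's parameter count changes after the move, which is what happens on TPU/XLA when tied weights get re-created as separate tensors instead of staying tied (the warning text matches the one asserted in the test below). The following is only a minimal sketch of such a decorator, not necessarily the exact implementation in `pytorch_lightning.core.decorators`:

```python
from functools import wraps
from typing import Callable

from pytorch_lightning.utilities import rank_zero_warn


def parameter_validation(fn: Callable) -> Callable:
    """Sketch of a parameter-count check around ``model_to_device``.

    Assumes the decorated method lives on an object exposing ``self.model``,
    which is exactly why the reviewers below suggest moving it next to the plugin.
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_param_count = len(list(self.model.parameters()))
        result = fn(self, *args, **kwargs)
        post_param_count = len(list(self.model.parameters()))

        # A changed count after the device move suggests weight tying was broken.
        if pre_param_count != post_param_count:
            rank_zero_warn(
                "The model layers do not match after moving to the target device."
                " If your model employs weight sharing on TPU,"
                " please tie your weights using the `on_post_move_to_device` model hook."
            )
        return result

    return inner_fn
```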
```diff
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
         trainer.fit(model)


-# @RunIf(tpu=True)
-# @pl_multi_process_test
-# def test_if_weights_tied(tmpdir, capsys=None):
-#     """
-#     Test if weights are properly tied on `on_post_move_to_device`.
-#     Ensure no warning for parameter mismatch is thrown.
-#     """
-
-#     # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
-#     class Model(WeightSharingModule):
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """

-#         def on_post_move_to_device(self):
-#             self.layer_3.weight = self.layer_1.weight
+    class Model(WeightSharingModule):

-#     model = Model()
-#     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+        def on_post_move_to_device(self):
```
Contributor (inline comment on `def on_post_move_to_device`): Could we make this check slightly smarter by checking parameter names? If I tie `self.layer_3.weight = self.layer_1.weight` in the init function and then mess up and do `self.layer_3.weight = self.layer_2.weight`, I won't get a warning, but the tying is different. Ideally it would be great to explicitly tell which weights are shared, or to do it automatically for the user. [A name-based check is sketched after this diff.]

Contributor (Author): Interesting, will follow up.
```diff
+            self.layer_3.weight = self.layer_1.weight

-#     with pytest.warns(UserWarning) as warnings:
-#         trainer.fit(model)
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

-#     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-#     assert len(trainer.test(model)) == 1
+    with pytest.warns(UserWarning, match="The model layers do not match"):
+        trainer.fit(model)
```
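To illustrate the name-based check suggested in the inline comment above, here is a hedged sketch: record which parameter names point at the same `Parameter` object before the device move and verify the same groups still share after it. The helpers `find_shared_parameters` and `check_shared_parameters` are hypothetical names, not part of this PR.

```python
from collections import defaultdict
from typing import Dict, List

import torch.nn as nn


def find_shared_parameters(module: nn.Module) -> List[List[str]]:
    """Group fully qualified parameter names that refer to the same Parameter object."""
    groups: Dict[int, List[str]] = defaultdict(list)
    for mod_name, mod in module.named_modules():
        # recurse=False keeps tied parameters visible under every module that owns them
        for param_name, param in mod.named_parameters(recurse=False):
            full_name = f"{mod_name}.{param_name}" if mod_name else param_name
            groups[id(param)].append(full_name)
    return [sorted(names) for names in groups.values() if len(names) > 1]


def check_shared_parameters(module: nn.Module, expected: List[List[str]]) -> None:
    """Raise if the tying groups differ from the ones recorded before the device move."""
    actual = find_shared_parameters(module)
    if sorted(actual) != sorted(sorted(names) for names in expected):
        raise RuntimeError(
            "Weight tying changed after moving the model to the device: "
            f"expected {expected}, found {actual}."
        )
```

The plugin could call `find_shared_parameters` before `model_to_device` and `check_shared_parameters` afterwards: re-tying to the wrong layer in `on_post_move_to_device` (for example `layer_2` instead of `layer_1`) changes the name groups and gets reported, even though the raw parameter count stays the same.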
Reviewer: I think that now that you changed the decorator target to `self.model`, this decorator may no longer fit very well into `core/decorators`, because it is basically now specific to a plugin that has the attribute `self.model`. What do you think about moving it? Just for consideration.

Reviewer: +1 to @awaelchli's suggestion.

Author: Good catch. Will do a follow-up PR for this.
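As an illustration of the suggested follow-up only (not part of this PR), the validation could live next to the TPU plugin and be called explicitly from `model_to_device`, removing the `core/decorators` dependency on `self.model`. The helper name and the commented call site below are hypothetical:

```python
import warnings

import torch.nn as nn


def warn_if_parameter_count_changed(pre_count: int, module: nn.Module) -> None:
    """Hypothetical plugin-local helper: warn when the parameter count changed
    after the model was moved to the TPU device (a sign that weight tying broke)."""
    if pre_count != len(list(module.parameters())):
        warnings.warn(
            "The model layers do not match after moving to the target device."
            " If your model employs weight sharing on TPU,"
            " please tie your weights using the `on_post_move_to_device` model hook.",
            UserWarning,
        )


# Inside TPUSpawnPlugin.model_to_device, the call could look roughly like:
#
#     def model_to_device(self) -> None:
#         pre_count = len(list(self.model.parameters()))
#         self.model = self.wrapped_model.to(self.root_device)
#         warn_if_parameter_count_changed(pre_count, self.model)
```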